{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9626b5a9-adb5-42fd-aa59-44009c214349",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pathlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad6e88ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure palette: matplotlib's default blue plus an orange accent.\n",
    "color1, color2 = \"#1f77b4\", \"#fe6100\"\n",
    "\n",
    "# Seaborn base look: the compact \"paper\" context on a plain white background.\n",
    "sns.set_context(\"paper\")\n",
    "sns.set_style(\"white\")\n",
    "\n",
    "# Font sizes: fs1 for text, axis labels, titles and ticks; fs2 for legends.\n",
    "# Applied after seaborn so these values take precedence.\n",
    "fs1, fs2 = 28, 23\n",
    "rcParams.update(\n",
    "    {\n",
    "        \"font.family\": \"sans-serif\",\n",
    "        \"font.size\": fs1,\n",
    "        \"axes.labelsize\": fs1,\n",
    "        \"axes.titlesize\": fs1,\n",
    "        \"xtick.labelsize\": fs1,\n",
    "        \"ytick.labelsize\": fs1,\n",
    "        \"legend.fontsize\": fs2,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea7a94c",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEFAULT_ALGORITHM_NAMES = [\n",
    "    \"lloyd\",\n",
    "    \"hartigan\",\n",
    "    \"pca\",\n",
    "    \"pca_split\",\n",
    "    \"sdp\",\n",
    "    \"spectral\",\n",
    "]\n",
    "\n",
    "ALGORITHM_LABELS = {\n",
    "    \"lloyd\": \"Lloyd\",\n",
    "    \"hartigan\": \"Hartigan\",\n",
    "    \"bhartigan\": \"B-Hartigan\",\n",
    "    \"mbhartigan\": \"MB-Hartigan\",\n",
    "    \"pca\": \"PCA + Lloyd\",\n",
    "    \"pca_split\": \"PCA + Split\",\n",
    "    \"sdp\": \"SDP\",\n",
    "    \"spectral\": \"Spectral\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc918635",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss_metric(loss, loss_ref):\n",
    "    return np.where(\n",
    "        np.abs(loss - loss_ref) / np.abs(loss_ref) < 1e-6,\n",
    "        0,\n",
    "        np.where(loss - loss_ref < 0, 1, -1),\n",
    "    )\n",
    "\n",
    "\n",
    "def compute_loss_metric_for_results(results, algorithm_names=None):\n",
    "    \"\"\"Per-method mean loss comparison against the true partition.\n",
    "\n",
    "    For each name in ``algorithm_names`` (default ``DEFAULT_ALGORITHM_NAMES``),\n",
    "    ``results[name][\"loss\"]`` is compared with\n",
    "    ``results[\"true_partition\"][\"loss\"]`` via ``compute_loss_metric`` and the\n",
    "    outcome is averaged over the last axis. Returns ``{name: averaged array}``.\n",
    "    \"\"\"\n",
    "    names = DEFAULT_ALGORITHM_NAMES if algorithm_names is None else algorithm_names\n",
    "    return {\n",
    "        name: compute_loss_metric(\n",
    "            results[name][\"loss\"], results[\"true_partition\"][\"loss\"]\n",
    "        ).mean(-1)\n",
    "        for name in names\n",
    "    }\n",
    "\n",
    "\n",
    "def compute_nmi_metrics_for_results(results, algorithm_names=None):\n",
    "    if algorithm_names is None:\n",
    "        algorithm_names = DEFAULT_ALGORITHM_NAMES\n",
    "\n",
    "    nmi_metrics = {}\n",
    "    for method in algorithm_names:\n",
    "        nmi_metrics[method] = results[method][\"nmi\"].mean(-1)\n",
    "    return nmi_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfc217f8-376a-4ba6-926c-002b39eab4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_nmi(nmi_matrix, ax):\n",
    "    im = ax.imshow(nmi_matrix.T, origin=\"lower\", vmin=0, vmax=1, cmap=\"cividis\")\n",
    "    return im\n",
    "\n",
    "\n",
    "def plot_loss(loss_matrix, ax):\n",
    "    im = ax.imshow(loss_matrix.T, origin=\"lower\", vmin=-1, vmax=1, cmap=\"bwr_r\")\n",
    "    return im\n",
    "\n",
    "\n",
    "def plot_metrics(\n",
    "    metrics_dict,\n",
    "    results_dict,\n",
    "    metric_type,\n",
    "    algorithm_names,\n",
    "    cluster_sizes,\n",
    "    fig_fname=None,\n",
    "    fig_suptitle=None,\n",
    "):\n",
    "    \"\"\"Grid of metric heatmaps: one row per K, one column per algorithm.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    metrics_dict : dict\n",
    "        ``metrics_dict[f\"k{k}{init}\"][algorithm]`` is a 2-D metric matrix as\n",
    "        returned by ``compute_nmi_metrics_for_results`` or\n",
    "        ``compute_loss_metric_for_results``.\n",
    "    results_dict : dict\n",
    "        Loaded results; ``results_dict[\"k2_kpp\"]`` supplies the\n",
    "        ``dimension_vals`` / ``noise_variance_vals`` used for tick labels.\n",
    "    metric_type : {\"nmi\", \"loss\"}\n",
    "        Selects the colour scale (``plot_nmi`` or ``plot_loss``).\n",
    "    algorithm_names : sequence of str\n",
    "        Column order. \"hartigan\" is read from the ``_randpar`` runs, every\n",
    "        other algorithm from the ``_kpp`` runs.\n",
    "    cluster_sizes : sequence of int\n",
    "        Row order (values of K).\n",
    "    fig_fname : str or path-like, optional\n",
    "        Where to save the figure; parent directories are created.\n",
    "    fig_suptitle : str, optional\n",
    "        Figure super-title.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig : matplotlib.figure.Figure\n",
    "    ax : numpy.ndarray\n",
    "        Axes array, always 2-D with shape\n",
    "        ``(len(cluster_sizes), len(algorithm_names))``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``metric_type`` is neither \"nmi\" nor \"loss\".\n",
    "    \"\"\"\n",
    "    if metric_type not in [\"nmi\", \"loss\"]:\n",
    "        raise ValueError(f\"Unknown metric type {metric_type}\")\n",
    "\n",
    "    subplot_fn = plot_nmi if metric_type == \"nmi\" else plot_loss\n",
    "    dimension_vals = results_dict[\"k2_kpp\"][\"dimension_vals\"]\n",
    "\n",
    "    # removing smallest 4 noise variance values for better visualization\n",
    "    noise_variance_vals = results_dict[\"k2_kpp\"][\"noise_variance_vals\"][4:]\n",
    "\n",
    "    x_tick_indices = np.array([1, 7, 13, 19])\n",
    "    x_tick_labels = [f\"{d:.0f}\" for d in np.log10(dimension_vals[x_tick_indices])]\n",
    "    # Brace the exponent so multi-character exponents (e.g. \"10\" or \"-1\") are\n",
    "    # rendered entirely as a superscript by mathtext.\n",
    "    x_tick_labels = [rf\"$10^{{{d}}}$\" for d in x_tick_labels]\n",
    "\n",
    "    y_tick_indices = np.array([0, 3, 7, 11, 15])\n",
    "    y_tick_labels = noise_variance_vals[y_tick_indices].round(1)\n",
    "\n",
    "    n_rows, n_algos = len(cluster_sizes), len(algorithm_names)\n",
    "    # squeeze=False always yields a 2-D axes array, including the 1x1 case in\n",
    "    # which plt.subplots would otherwise return a bare Axes object.\n",
    "    fig, ax = plt.subplots(\n",
    "        n_rows,\n",
    "        n_algos,\n",
    "        figsize=(5 * n_algos, 4 * n_rows),\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        squeeze=False,\n",
    "        layout=\"compressed\",\n",
    "    )\n",
    "\n",
    "    row_images = []\n",
    "    for i, k in enumerate(cluster_sizes):\n",
    "        for j, name in enumerate(algorithm_names):\n",
    "            # Hartigan is taken from the \"_randpar\" runs, the rest from \"_kpp\".\n",
    "            init = \"_randpar\" if name == \"hartigan\" else \"_kpp\"\n",
    "            im = subplot_fn(metrics_dict[f\"k{k}{init}\"][name][:, 4:], ax[i, j])\n",
    "        # Every panel uses the same fixed colour limits, so the row's last\n",
    "        # image is representative for that row's colourbar.\n",
    "        row_images.append(im)\n",
    "\n",
    "    for j in range(n_algos):\n",
    "        ax[-1, j].set_xticks(x_tick_indices, x_tick_labels)\n",
    "\n",
    "    for i, im in enumerate(row_images):\n",
    "        ax[i, 0].set_yticks(y_tick_indices, y_tick_labels)\n",
    "        # Attach each colourbar to its whole row, as the other figures do.\n",
    "        cbar = fig.colorbar(im, ax=ax[i, :])\n",
    "        if metric_type == \"loss\":\n",
    "            cbar.set_ticks([-1, 0, 1])\n",
    "            cbar.set_ticklabels([\"GT\", \"Tie\", \"Clustering\"])\n",
    "\n",
    "    if fig_suptitle is not None:\n",
    "        fig.suptitle(fig_suptitle)\n",
    "\n",
    "    if fig_fname is not None:\n",
    "        pathlib.Path(fig_fname).parent.mkdir(parents=True, exist_ok=True)\n",
    "        # Save this figure explicitly rather than pyplot's \"current\" figure.\n",
    "        fig.savefig(fig_fname, bbox_inches=\"tight\", dpi=600)\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ff2599-064a-41fb-bc17-7fceef195e2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_results(filename):\n",
    "    results = dict(np.load(filename, allow_pickle=True))\n",
    "\n",
    "    for key in results.keys():  # relevant_keys:\n",
    "        if isinstance(results[key], np.ndarray) and results[key].size == 1:\n",
    "            results[key] = results[key].item()\n",
    "    return results\n",
    "\n",
    "\n",
    "def load_results_and_merge(n_clusters, init, output_dir=\"./outputs\"):\n",
    "    \"\"\"Load k-means results for one (K, init) pair and merge SDP/spectral results.\n",
    "\n",
    "    Reads ``experiments_kmeans_{n_clusters}clusters{init}.npz`` from\n",
    "    ``output_dir`` and updates it with the spectral and then the SDP file for\n",
    "    the same K (those two file names do not depend on ``init``). Keys present\n",
    "    in several files take the value from the file loaded last.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_clusters : int\n",
    "        Number of clusters K encoded in the file names.\n",
    "    init : str\n",
    "        Initialisation suffix of the k-means file, e.g. ``\"_kpp\"``.\n",
    "    output_dir : str, optional\n",
    "        Directory holding the ``.npz`` files (default ``\"./outputs\"``).\n",
    "    \"\"\"\n",
    "    prefix = f\"{output_dir}/experiments\"\n",
    "    results = load_results(f\"{prefix}_kmeans_{n_clusters}clusters{init}.npz\")\n",
    "    for method in (\"spectral\", \"sdp\"):\n",
    "        results.update(load_results(f\"{prefix}_{method}_{n_clusters}clusters.npz\"))\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d0fd0dc",
   "metadata": {},
   "source": [
    "# Load results and compute metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2183bfc1-ce82-4063-b1f2-5347c39a571a",
   "metadata": {},
   "outputs": [],
   "source": [
    "algorithm_names = [\"lloyd\", \"hartigan\", \"pca\", \"pca_split\", \"sdp\", \"spectral\"]\n",
    "\n",
    "# Load every (init, K) combination and derive its per-method loss and NMI\n",
    "# metrics; all three dicts are keyed like \"k2_kpp\".\n",
    "results_dict = {}\n",
    "loss_metrics = {}\n",
    "nmi_metrics = {}\n",
    "for init in [\"_random\", \"_kpp\", \"_randpar\"]:\n",
    "    for n_clusters in [2, 5, 10]:\n",
    "        key = f\"k{n_clusters}{init}\"\n",
    "        merged = load_results_and_merge(n_clusters, init)\n",
    "        results_dict[key] = merged\n",
    "        loss_metrics[key] = compute_loss_metric_for_results(merged, algorithm_names)\n",
    "        nmi_metrics[key] = compute_nmi_metrics_for_results(\n",
    "            merged, algorithm_names=algorithm_names\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b90838c",
   "metadata": {},
   "source": [
    "# NMI and Loss vs K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d837a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per metric: rows are K = 2, 5, 10, columns the listed algorithms.\n",
    "for metrics, metric_type, fig_fname in [\n",
    "    (nmi_metrics, \"nmi\", \"plots/figures_vs_k/figure_nmi_raw.svg\"),\n",
    "    (loss_metrics, \"loss\", \"plots/figures_vs_k/figure_loss_raw.svg\"),\n",
    "]:\n",
    "    plot_metrics(\n",
    "        metrics,\n",
    "        results_dict,\n",
    "        metric_type=metric_type,\n",
    "        algorithm_names=[\"lloyd\", \"hartigan\", \"pca\", \"sdp\", \"spectral\"],\n",
    "        cluster_sizes=[2, 5, 10],\n",
    "        fig_fname=fig_fname,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a70e2ad",
   "metadata": {},
   "source": [
    "# NMI and Loss vs Init (Lloyd and Hartigan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93246639",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metrics_vs_init(\n",
    "    metrics_dict,\n",
    "    results_dict,\n",
    "    metric_type,\n",
    "    fig_fname=None,\n",
    "    fig_suptitle=None,\n",
    "    algorithm_names=(\"lloyd\", \"hartigan\"),\n",
    "    inits=(\"_randpar\", \"_random\", \"_kpp\"),\n",
    "    cluster_sizes=(2, 5, 10),\n",
    "):\n",
    "    \"\"\"Grid of metric heatmaps comparing initialisations across algorithms.\n",
    "\n",
    "    One row per K in ``cluster_sizes``; columns are grouped by initialisation\n",
    "    (in ``inits`` order) with the ``algorithm_names`` side by side inside each\n",
    "    group. With the defaults this is a 3 x 6 grid of 30 x 12 inches.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    metrics_dict : dict\n",
    "        ``metrics_dict[f\"k{k}{init}\"][algorithm]`` is a 2-D metric matrix.\n",
    "    results_dict : dict\n",
    "        Loaded results; ``results_dict[\"k2_kpp\"]`` supplies the tick values.\n",
    "    metric_type : {\"nmi\", \"loss\"}\n",
    "        Selects ``plot_nmi`` or ``plot_loss``.\n",
    "    fig_fname : str or path-like, optional\n",
    "        Where to save the figure; parent directories are created.\n",
    "    fig_suptitle : str, optional\n",
    "        Figure super-title.\n",
    "    algorithm_names : sequence of str, optional\n",
    "        Algorithms shown within each initialisation group.\n",
    "    inits : sequence of str, optional\n",
    "        Initialisation suffixes, in column-group order.\n",
    "    cluster_sizes : sequence of int, optional\n",
    "        Values of K, one per row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fig, ax\n",
    "        The figure and its 2-D axes array.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``metric_type`` is neither \"nmi\" nor \"loss\".\n",
    "    \"\"\"\n",
    "    if metric_type not in [\"nmi\", \"loss\"]:\n",
    "        raise ValueError(f\"Unknown metric type {metric_type}\")\n",
    "\n",
    "    subplot_fn = plot_nmi if metric_type == \"nmi\" else plot_loss\n",
    "    dimension_vals = results_dict[\"k2_kpp\"][\"dimension_vals\"]\n",
    "\n",
    "    # removing smallest 4 noise variance values for better visualization\n",
    "    noise_variance_vals = results_dict[\"k2_kpp\"][\"noise_variance_vals\"][4:]\n",
    "\n",
    "    x_tick_indices = np.array([1, 7, 13, 19])\n",
    "    x_tick_labels = [f\"{d:.0f}\" for d in np.log10(dimension_vals[x_tick_indices])]\n",
    "    # Brace the exponent so multi-character exponents stay in the superscript.\n",
    "    x_tick_labels = [rf\"$10^{{{d}}}$\" for d in x_tick_labels]\n",
    "\n",
    "    y_tick_indices = np.array([0, 3, 7, 11, 15])\n",
    "    y_tick_labels = noise_variance_vals[y_tick_indices].round(1)\n",
    "\n",
    "    # Column order: every algorithm for the first init, then the next init, ...\n",
    "    columns = [(init, name) for init in inits for name in algorithm_names]\n",
    "    n_rows, n_cols = len(cluster_sizes), len(columns)\n",
    "    fig, ax = plt.subplots(\n",
    "        n_rows,\n",
    "        n_cols,\n",
    "        figsize=(5 * n_cols, 4 * n_rows),\n",
    "        sharex=True,\n",
    "        sharey=True,\n",
    "        squeeze=False,\n",
    "        layout=\"compressed\",\n",
    "    )\n",
    "\n",
    "    row_images = []\n",
    "    for i, k in enumerate(cluster_sizes):\n",
    "        for j, (init, name) in enumerate(columns):\n",
    "            im = subplot_fn(metrics_dict[f\"k{k}{init}\"][name][:, 4:], ax[i, j])\n",
    "        # Colour limits are fixed, so the row's last image serves its colourbar.\n",
    "        row_images.append(im)\n",
    "\n",
    "    for j in range(n_cols):\n",
    "        ax[-1, j].set_xticks(x_tick_indices, x_tick_labels)\n",
    "\n",
    "    for i, im in enumerate(row_images):\n",
    "        ax[i, 0].set_yticks(y_tick_indices, y_tick_labels)\n",
    "        cbar = fig.colorbar(im, ax=ax[i, :])\n",
    "        if metric_type == \"loss\":\n",
    "            cbar.set_ticks([-1, 0, 1])\n",
    "            cbar.set_ticklabels([\"GT\", \"Tie\", \"Clustering\"])\n",
    "\n",
    "    if fig_suptitle is not None:\n",
    "        fig.suptitle(fig_suptitle)\n",
    "\n",
    "    if fig_fname is not None:\n",
    "        pathlib.Path(fig_fname).parent.mkdir(parents=True, exist_ok=True)\n",
    "        # Save this figure explicitly rather than pyplot's \"current\" figure.\n",
    "        fig.savefig(fig_fname, bbox_inches=\"tight\", dpi=600)\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a81229d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): not passed to plot_metrics_vs_init below.\n",
    "alg_names_for_plots = [\"lloyd\", \"hartigan\"]\n",
    "\n",
    "# Rows: K = 2, 5, 10; columns: each initialisation x (Lloyd, Hartigan).\n",
    "for metrics, metric_type, fig_fname in [\n",
    "    (nmi_metrics, \"nmi\", \"plots/figures_vs_init/figure_nmi_raw.svg\"),\n",
    "    (loss_metrics, \"loss\", \"plots/figures_vs_init/figure_loss_raw.svg\"),\n",
    "]:\n",
    "    plot_metrics_vs_init(\n",
    "        metrics,\n",
    "        results_dict,\n",
    "        metric_type=metric_type,\n",
    "        fig_fname=fig_fname,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c38c7fb9",
   "metadata": {},
   "source": [
    "# Lloyd iteration counts vs initialisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ebd630",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_iterations(results, ax, vmin, vmax):\n",
    "    iterations_matrix = results[\"lloyd\"][\"n_iter\"].mean(-1)[:, 4:]\n",
    "    im = ax.imshow(\n",
    "        iterations_matrix.T, origin=\"lower\", cmap=\"cividis\", vmin=vmin, vmax=vmax\n",
    "    )\n",
    "\n",
    "    return im\n",
    "\n",
    "\n",
    "# Tick positions/labels; the first 4 noise-variance values are dropped, as in\n",
    "# the metric plots.\n",
    "dimension_vals = results_dict[\"k10_random\"][\"dimension_vals\"]\n",
    "noise_variance_vals = results_dict[\"k10_random\"][\"noise_variance_vals\"][4:]\n",
    "\n",
    "x_tick_indices = np.array([1, 7, 13, 19])\n",
    "x_tick_labels = [rf\"$10^{e:.0f}$\" for e in np.log10(dimension_vals[x_tick_indices])]\n",
    "\n",
    "y_tick_indices = np.array([0, 3, 7, 11, 15])\n",
    "y_tick_labels = noise_variance_vals[y_tick_indices].round(1)\n",
    "\n",
    "fig, ax = plt.subplots(\n",
    "    3, 3, figsize=(15, 12), layout=\"compressed\", sharex=True, sharey=True\n",
    ")\n",
    "\n",
    "# Rows: K = 2, 5, 10; columns: _randpar, _random, _kpp initialisation.\n",
    "# The colour scale is clipped to [1, 3], hence the \"3+\" colourbar label.\n",
    "row_images = []\n",
    "for row, k in enumerate([2, 5, 10]):\n",
    "    for col, init in enumerate([\"_randpar\", \"_random\", \"_kpp\"]):\n",
    "        image = plot_iterations(\n",
    "            results_dict[f\"k{k}{init}\"], ax[row, col], vmin=1, vmax=3\n",
    "        )\n",
    "    row_images.append(image)\n",
    "\n",
    "colorbars = [\n",
    "    plt.colorbar(image, ax=ax[row, :]) for row, image in enumerate(row_images)\n",
    "]\n",
    "for cbar in colorbars:\n",
    "    cbar.set_ticks([1, 2, 3])\n",
    "    cbar.set_ticklabels([\"1\", \"2\", \"3+\"])\n",
    "\n",
    "for idx in range(3):\n",
    "    ax[idx, 0].set_yticks(y_tick_indices, y_tick_labels)\n",
    "    ax[-1, idx].set_xticks(x_tick_indices, x_tick_labels)\n",
    "\n",
    "fig_fname = \"plots/figure_lloyd_iterations_raw.svg\"\n",
    "pathlib.Path(fig_fname).parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(fig_fname, bbox_inches=\"tight\", dpi=600)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "487d8a7a",
   "metadata": {},
   "source": [
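    "# Lloyd vs PCA + Split\n",
    "\n",
    "NMI (top row) and loss comparison with the true partition (bottom row) on the `k2_kpp` runs."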
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03ca596f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tick positions/labels; the first 4 noise-variance values are dropped, as in\n",
    "# the metric plots.\n",
    "dimension_vals = results_dict[\"k10_random\"][\"dimension_vals\"]\n",
    "noise_variance_vals = results_dict[\"k10_random\"][\"noise_variance_vals\"][4:]\n",
    "\n",
    "x_tick_indices = np.array([1, 7, 13, 19])\n",
    "x_tick_labels = [rf\"$10^{e:.0f}$\" for e in np.log10(dimension_vals[x_tick_indices])]\n",
    "\n",
    "y_tick_indices = np.array([0, 3, 7, 11, 15])\n",
    "y_tick_labels = noise_variance_vals[y_tick_indices].round(1)\n",
    "\n",
    "fig, ax = plt.subplots(\n",
    "    2, 2, figsize=(15, 12), layout=\"compressed\", sharex=False, sharey=True\n",
    ")\n",
    "\n",
    "# Top row: NMI; bottom row: loss vs. the true partition.\n",
    "# Left column: Lloyd; right column: PCA + Split (both on the k2_kpp runs).\n",
    "nmi_k2 = nmi_metrics[\"k2_kpp\"]\n",
    "loss_k2 = loss_metrics[\"k2_kpp\"]\n",
    "plot_nmi(nmi_k2[\"lloyd\"][:, 4:], ax[0, 0])\n",
    "im1 = plot_nmi(nmi_k2[\"pca_split\"][:, 4:], ax[0, 1])\n",
    "plot_loss(loss_k2[\"lloyd\"][:, 4:], ax[1, 0])\n",
    "im2 = plot_loss(loss_k2[\"pca_split\"][:, 4:], ax[1, 1])\n",
    "\n",
    "cbar1 = plt.colorbar(im1, ax=ax[0, :])\n",
    "cbar1.set_label(\"NMI\", rotation=270, labelpad=25)\n",
    "\n",
    "cbar2 = plt.colorbar(im2, ax=ax[1, :])\n",
    "cbar2.set_ticks([-1, 0, 1])\n",
    "cbar2.set_ticklabels([\"GT\", \"Tie\", \"Clustering\"])\n",
    "cbar2.set_label(\"Partition with better loss\", rotation=270, labelpad=30)\n",
    "\n",
    "for row_axes in ax:\n",
    "    row_axes[0].set_title(\"Lloyd\")\n",
    "    row_axes[1].set_title(\"PCA + Split\")\n",
    "    row_axes[0].set_yticks(y_tick_indices, y_tick_labels)\n",
    "    for panel in row_axes:\n",
    "        panel.set_xticks(x_tick_indices, x_tick_labels)\n",
    "\n",
    "fig_fname = \"plots/figure_pca_split_vs_lloyd_raw.svg\"\n",
    "pathlib.Path(fig_fname).parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(fig_fname, bbox_inches=\"tight\", dpi=600)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "befc8494",
   "metadata": {},
   "source": [
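    "# Comparing Methods\n",
    "\n",
    "Mean NMI differences between pairs of methods for K = 2, 5, 10. Row 1 compares Lloyd with Hartigan, SDP and Spectral, row 2 compares Hartigan with SDP and Spectral, and row 3 compares SDP with Spectral; positive values (red) mean the row's method reaches the higher mean NMI."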
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e22a42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_nmi_diff(nmi1, nmi2):\n",
    "    return (nmi1 - nmi2).mean(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb07396",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_nmi_difference(\n",
    "    nmi_method1,\n",
    "    nmi_method2,\n",
    "    ax,\n",
    "    vlim=0.3,\n",
    "    x_ticks=None,\n",
    "    y_ticks=None,\n",
    "    verbose=False,\n",
    "):\n",
    "    \"\"\"Draw the mean NMI difference ``method1 - method2`` on ``ax``.\n",
    "\n",
    "    The difference is computed by ``compare_nmi_diff`` (mean over the last\n",
    "    axis); the first 4 noise-variance columns are dropped and the transpose is\n",
    "    drawn with the \"bwr\" colormap on ``[-vlim, vlim]`` (red: method 1 higher).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    nmi_method1, nmi_method2 : numpy.ndarray\n",
    "        Per-run NMI arrays of the two methods.\n",
    "    ax : matplotlib.axes.Axes\n",
    "        Target axes.\n",
    "    vlim : float, optional\n",
    "        Symmetric colour limit (default 0.3).\n",
    "    x_ticks, y_ticks : tuple (positions, labels), optional\n",
    "        Tick specification. When omitted, the notebook globals\n",
    "        ``x_tick_indices``/``x_tick_labels`` and\n",
    "        ``y_tick_indices``/``y_tick_labels`` are used.\n",
    "    verbose : bool, optional\n",
    "        If True, print the minimum and maximum difference, e.g. to check\n",
    "        whether ``vlim`` clips the data. Off by default to keep output clean.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.image.AxesImage\n",
    "        The image returned by ``ax.imshow``.\n",
    "    \"\"\"\n",
    "    diff = compare_nmi_diff(nmi_method1, nmi_method2)[:, 4:]\n",
    "    if verbose:\n",
    "        print(np.min(diff), np.max(diff))\n",
    "    im = ax.imshow(diff.T, origin=\"lower\", cmap=\"bwr\", vmin=-vlim, vmax=vlim)\n",
    "\n",
    "    # Fall back to the notebook-level tick globals so existing calls keep working.\n",
    "    if x_ticks is None:\n",
    "        x_ticks = (x_tick_indices, x_tick_labels)\n",
    "    if y_ticks is None:\n",
    "        y_ticks = (y_tick_indices, y_tick_labels)\n",
    "    ax.set_xticks(*x_ticks)\n",
    "    ax.set_yticks(*y_ticks)\n",
    "    return im\n",
    "\n",
    "\n",
    "# Shared tick positions/labels read by plot_nmi_difference; the first 4\n",
    "# noise-variance values are dropped, as in the other figures.\n",
    "dimension_vals = results_dict[\"k2_kpp\"][\"dimension_vals\"]\n",
    "noise_variance_vals = results_dict[\"k2_kpp\"][\"noise_variance_vals\"][4:]\n",
    "\n",
    "x_tick_indices = np.array([1, 7, 13, 19])\n",
    "x_tick_labels = [rf\"$10^{e:.0f}$\" for e in np.log10(dimension_vals[x_tick_indices])]\n",
    "\n",
    "y_tick_indices = np.array([0, 3, 7, 11, 15])\n",
    "y_tick_labels = noise_variance_vals[y_tick_indices].round(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7308861",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise mean NMI differences, one upper-triangular 3x3 grid per K. Row r\n",
    "# compares a reference method (Lloyd, Hartigan, SDP) with the later methods;\n",
    "# positive values mean the row's reference method has the higher NMI.\n",
    "for k, lab in zip([2, 5, 10], [\"A\", \"B\", \"C\"]):\n",
    "    fig, ax = plt.subplots(\n",
    "        3, 3, figsize=(15, 14), layout=\"compressed\", sharex=False, sharey=False\n",
    "    )\n",
    "    kpp = results_dict[f\"k{k}_kpp\"]\n",
    "    randpar = results_dict[f\"k{k}_randpar\"]\n",
    "\n",
    "    # (row, col, first-method NMI, second-method NMI), in drawing order.\n",
    "    panels = [\n",
    "        (0, 0, kpp[\"lloyd\"][\"nmi\"], randpar[\"hartigan\"][\"nmi\"]),\n",
    "        (0, 1, kpp[\"lloyd\"][\"nmi\"], kpp[\"sdp\"][\"nmi\"]),\n",
    "        (0, 2, kpp[\"lloyd\"][\"nmi\"], kpp[\"spectral\"][\"nmi\"]),\n",
    "        (1, 1, randpar[\"hartigan\"][\"nmi\"], randpar[\"sdp\"][\"nmi\"]),\n",
    "        (1, 2, randpar[\"hartigan\"][\"nmi\"], randpar[\"spectral\"][\"nmi\"]),\n",
    "        (2, 2, randpar[\"sdp\"][\"nmi\"], randpar[\"spectral\"][\"nmi\"]),\n",
    "    ]\n",
    "    images = {}\n",
    "    for row, col, nmi_a, nmi_b in panels:\n",
    "        images[row, col] = plot_nmi_difference(nmi_a, nmi_b, ax[row, col])\n",
    "\n",
    "    # Top-row titles name the method Lloyd is compared against.\n",
    "    for col, title in enumerate([\"Hartigan\", \"SDP\", \"Spectral\"]):\n",
    "        ax[0, col].set_title(title)\n",
    "\n",
    "    # Hide the unused lower-triangular panels.\n",
    "    for row, col in [(1, 0), (2, 0), (2, 1)]:\n",
    "        ax[row, col].set_axis_off()\n",
    "\n",
    "    # Only the diagonal panels keep their tick labels.\n",
    "    for row, col in [(0, 1), (0, 2), (1, 2)]:\n",
    "        ax[row, col].set_xticks([])\n",
    "        ax[row, col].set_yticks([])\n",
    "\n",
    "    # One colourbar per row, attached to the row's right-most panel.\n",
    "    for row in range(3):\n",
    "        cbar = plt.colorbar(images[row, 2])\n",
    "        cbar.ax.set_yticks([-0.2, 0, 0.2])\n",
    "        cbar.ax.set_yticklabels([r\"$-0.2$\", \"0\", r\"$0.2$\"])\n",
    "\n",
    "    fig.suptitle(f\"({lab}) K = {k}\")\n",
    "    fig_fname = f\"plots/figures_comparison/figure_vs_k{k}.svg\"\n",
    "    pathlib.Path(fig_fname).parent.mkdir(parents=True, exist_ok=True)\n",
    "    fig.savefig(fig_fname, bbox_inches=\"tight\", dpi=600)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "obs-on-kmeans-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
