{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from utils.io_utils import load_multiple_res\n",
    "from utils.utils import get_path\n",
    "from utils.pd_utils import get_persistent_feature_id, get_persistent_cycle, compute_outlier_scores\n",
    "from utils.toydata_utils import get_toy_data\n",
    "from utils.fig_utils import full_dist_to_color, all_full_dists, dist_to_color, full_dist_to_print, plot_edges_on_scatter\n",
    "from vis_utils.plot import plot_scatter\n",
    "from vis_utils.utils import load_dict\n",
    "from persim import plot_diagrams\n",
    "import os\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Activate the project's matplotlib style sheet; the same path is later handed to persim's plot_diagrams.\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "814abc81ca7a729a"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Data root of the project and the sub-folder that receives the exported figures.\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")\n",
    "# fig.savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(fig_path, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2168d746d1c013dd"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Show all colors: one unit-height bar per distance variant, grouped by distance family.\n",
    "fig, ax = plt.subplots(figsize=(5.5, 2.5))\n",
    "\n",
    "bar_width = 0.2\n",
    "sep = 1.5 * bar_width  # extra horizontal gap in front of every family\n",
    "\n",
    "x_vals = [0]  # the leading 0 only anchors the first group and is never used as a tick\n",
    "dist_strs = []\n",
    "for full_dists in all_full_dists.values():\n",
    "    x_val = x_vals[-1] + sep\n",
    "    for full_dist in full_dists:\n",
    "        x_val += bar_width\n",
    "        x_vals.append(x_val)\n",
    "        dist_strs.append(full_dist)\n",
    "        ax.bar(x_val, 1, bar_width, label=full_dist, color=full_dist_to_color[full_dist])\n",
    "\n",
    "ax.set_xticks(x_vals[1:])\n",
    "# print names may contain line breaks; flatten them so the rotated tick labels stay on one line\n",
    "_ = ax.set_xticklabels([full_dist_to_print[dist_str].replace(\"\\n\", \" \") for dist_str in dist_strs], rotation=90)\n",
    "# uncomment to export the overview:\n",
    "# fig.savefig(os.path.join(fig_path, \"all_colors.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "63a010e2f9e9fb8e"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Figure 1"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8c1fbe1b18b68fd8"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Settings for Figure 1; these names are read by the loading and plotting cells below.\n",
    "distances = {\n",
    "    \"euclidean\": [{}],\n",
    "    \"eff_res\": [dict(corrected=True, weighted=False, k=100, disconnect=True)],\n",
    "}\n",
    "\n",
    "n = 1000\n",
    "embd_dim = 50\n",
    "seeds = [0, 1, 2]\n",
    "seed = 1  # used both as index into per-seed results and as the seed in embedding file names\n",
    "\n",
    "sigma = 0.25  # noise level of the qualitative panels; must be an element of `sigmas`\n",
    "# Noise grid: every value is round-tripped through a decimal string with at most 4 digits,\n",
    "# which strips the floating-point round-off of linspace (e.g. so that 0.25 is exactly 0.25).\n",
    "noise_grid = np.linspace(0.0, 0.35, 29)\n",
    "sigmas = np.array(\n",
    "    [np.format_float_positional(value, precision=4, unique=True, trim=\"0\") for value in noise_grid]\n",
    ").astype(float)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "71ca6a07faca7212"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load runs for the toy circle with the Euclidean and the effective resistance setting\n",
    "all_res = load_multiple_res(datasets=\"toy_circle\", distances=distances, root_path=root_path, n=n, embd_dims=embd_dim, sigmas=sigmas, seeds=seeds, n_threads=10)\n",
    "\n",
    "# key under which the effective resistance variant from `distances` is stored\n",
    "eff_res_key = \"eff_res_corrected_True_weighted_False_k_100_disconnect_True\"\n",
    "\n",
    "# results at the noise level shown in the qualitative panels\n",
    "res_eucl = all_res[\"euclidean\"][\"euclidean\"][sigma]\n",
    "res_eff = all_res[\"eff_res\"][eff_res_key][sigma]\n",
    "\n",
    "# compute outlier scores\n",
    "outlier_scores = compute_outlier_scores(dgms=all_res, n_features=1, dim=1)\n",
    "outlier_scores_eucl = outlier_scores[\"euclidean\"][\"euclidean\"]\n",
    "outlier_scores_eff = outlier_scores[\"eff_res\"][eff_res_key]\n",
    "\n",
    "# load circle data, one data set per seed (loop variable named `s` so it cannot be confused with `seed`)\n",
    "data_circle = np.array([get_toy_data(dataset=\"toy_circle\", n=n, gaussian={\"sigma\": sigma}, seed=s) for s in seeds])\n",
    "\n",
    "# load UMAP and tSNE; file names are built from the config values instead of repeating them as literals\n",
    "dataset_dir = os.path.join(root_path, \"toy_circle\")\n",
    "\n",
    "umap_file_name = f\"umap_n_{n}_d_{embd_dim}_ortho_gauss_sigma_{sigma}_k_15_metric_euclidean_epochs_750_seed_{seed}_min_dist_0.1_init_pca.pkl\"\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files that were written by this project.\n",
    "with open(os.path.join(dataset_dir, umap_file_name), \"rb\") as f:\n",
    "    umapper = pickle.load(f)\n",
    "umap_embd = umapper.embedding_\n",
    "\n",
    "tsne_file_name = f\"tsne_n_{n}_d_{embd_dim}_ortho_gauss_sigma_{sigma}_perplexity_30_n_epochs_500_n_early_epochs_250_seed_{seed}_init_pca_rescale_True.pkl\"\n",
    "tsne_data = load_dict(os.path.join(dataset_dir, tsne_file_name))\n",
    "tsne_embd = tsne_data[\"embds\"][-1]  # last stored embedding"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "c5a853598b44f82c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure\n",
    "# Mosaic keys vs. printed panel letters: a -> a, b/c -> b (top/bottom), d -> c, e -> d, f -> e.\n",
    "mosaic = \"\"\"\n",
    "aabd\n",
    "aabd\n",
    "aabd\n",
    "aacd\n",
    "efcd\n",
    "efcd\n",
    "\"\"\"\n",
    "fig, ax_dict = plt.subplot_mosaic(\n",
    "    mosaic,\n",
    "    height_ratios=(1, 1, 1, 1, 1, 1),\n",
    "    width_ratios=(0.9, 0.9, 0.8, 1.5),\n",
    "    figsize=(5.5, 2.0),\n",
    ")\n",
    "fig.get_layout_engine().set(w_pad=4 / 72, h_pad=0, hspace=0.02, wspace=0.02)\n",
    "\n",
    "\n",
    "def add_panel_letter(ax, letter, x):\n",
    "    \"\"\"Write the bold panel letter `letter` just above `ax`, at axes-fraction position `x`.\"\"\"\n",
    "    ax.text(x, 1.02, letter, transform=ax.transAxes, ha=\"left\", va=\"bottom\", fontweight=\"bold\", fontsize=7)\n",
    "\n",
    "\n",
    "def plot_score_band(ax, scores, color, label):\n",
    "    \"\"\"Plot the mean of `scores` along axis 1 over `sigmas` with a +/- one std band.\"\"\"\n",
    "    mean, std = scores.mean(1), scores.std(1)\n",
    "    ax.plot(sigmas, mean, color=color, label=label)\n",
    "    ax.fill_between(sigmas, mean + std, mean - std, color=color, alpha=0.2, edgecolor=None, clip_on=False)\n",
    "\n",
    "\n",
    "# panel a)\n",
    "## plot ring\n",
    "pca_raw = PCA(2).fit_transform(data_circle[seed])  # fitted once, reused for the mirrored loop coordinates\n",
    "mask = pca_raw[:, 1] < pca_raw[:, 1].max()  # hide the point(s) with the largest second coordinate\n",
    "plot_scatter(ax=ax_dict[\"a\"], x=pca_raw[mask], s=2, y=\"k\", scalebar=False, alpha=1.0)\n",
    "ax_dict[\"a\"].set_title(\"Representative loops on 2D PCA\\n\", pad=3)\n",
    "add_panel_letter(ax_dict[\"a\"], \"a\\n\", x=0.1)\n",
    "\n",
    "## plot loops\n",
    "# NOTE(review): the scatter above uses the unmirrored coordinates while the loops use the mirrored\n",
    "# ones, so loop vertices do not coincide with the plotted points - confirm that this is intended.\n",
    "pca = pca_raw @ np.array([[1, 0], [0, -1]])  # mirror for nicer figure layout\n",
    "loop_eucl = get_persistent_cycle(res_eucl[seed], m=1, dim=1, mode=\"additive\")\n",
    "loop_eff = get_persistent_cycle(res_eff[seed], m=1, dim=1, mode=\"additive\")\n",
    "plot_edges_on_scatter(ax=ax_dict[\"a\"], x=pca, edge_idx=loop_eucl, color=dist_to_color[\"euclidean\"], linewidth=1)\n",
    "plot_edges_on_scatter(ax=ax_dict[\"a\"], x=pca, edge_idx=loop_eff, color=dist_to_color[\"eff_res\"], linewidth=1)\n",
    "\n",
    "\n",
    "# panel b (persistences)\n",
    "loop_id_eucl = get_persistent_feature_id(res_eucl[seed], m=1, dim=1, mode=\"additive\")\n",
    "loop_id_eff = get_persistent_feature_id(res_eff[seed], m=1, dim=1, mode=\"additive\")\n",
    "# diagram points of the two selected loops; they are marked with a cross and used as arrow origins\n",
    "point_eucl = res_eucl[seed][\"dgms\"][1][loop_id_eucl]\n",
    "point_eff = res_eff[seed][\"dgms\"][1][loop_id_eff]\n",
    "\n",
    "## panel b top: PD for Euclidean distance\n",
    "plot_diagrams(\n",
    "    res_eucl[seed][\"dgms\"],\n",
    "    ax=ax_dict[\"b\"],\n",
    "    plot_only=[1],\n",
    "    size=4,\n",
    "    color=dist_to_color[\"euclidean\"],\n",
    "    colormap=style_file,  # necessary bc plot_diagrams uses the colormap as mpl style\n",
    ")\n",
    "ax_dict[\"b\"].scatter(*point_eucl.T, marker=\"x\", color=dist_to_color[\"euclidean\"], s=5)\n",
    "ax_dict[\"b\"].legend().set_visible(False)\n",
    "\n",
    "# escaped backslash (\"\\s\" is an invalid escape sequence) and the value taken from `sigma`\n",
    "ax_dict[\"b\"].set_title(f\"Persistence\\n$\\\\sigma = {sigma}$\")\n",
    "ax_dict[\"b\"].set_title(\n",
    "    \"b\\n\",\n",
    "    ha=\"right\",\n",
    "    loc=\"left\",\n",
    "    fontweight=\"bold\",\n",
    ")\n",
    "ax_dict[\"b\"].set_xlabel(\"\")\n",
    "# an empty tick list removes ticks and tick labels at once\n",
    "ax_dict[\"b\"].set_xticks([])\n",
    "ax_dict[\"b\"].set_yticks([])\n",
    "\n",
    "ax_dict[\"b\"].text(\n",
    "    s=\"Euclidean\",\n",
    "    x=0.4,\n",
    "    y=0.1,\n",
    "    transform=ax_dict[\"b\"].transAxes,\n",
    ")\n",
    "\n",
    "# arrows to representatives: each starts at the marked diagram point (xytext) and ends at a\n",
    "# hand-picked position in the data coordinates of panel a (xy)\n",
    "annot1 = plt.Annotation(\n",
    "    \"\",\n",
    "    xy=(1.2, 0.8),\n",
    "    xycoords=ax_dict[\"a\"].transData,\n",
    "    xytext=point_eucl,\n",
    "    textcoords=ax_dict[\"b\"].transData,\n",
    "    arrowprops=dict(arrowstyle=\"->\", linewidth=0.5, connectionstyle=\"angle3,angleA=80,angleB=10\"),\n",
    ")\n",
    "ax_dict[\"b\"].add_artist(annot1)\n",
    "\n",
    "annot2 = plt.Annotation(\n",
    "    \"\",\n",
    "    xy=(1.7, -0.32),\n",
    "    xycoords=ax_dict[\"a\"].transData,\n",
    "    xytext=point_eff,\n",
    "    textcoords=ax_dict[\"c\"].transData,\n",
    "    arrowprops=dict(\n",
    "        arrowstyle=\"->\",\n",
    "        linewidth=0.5,\n",
    "        connectionstyle=\"angle3,angleA=-90,angleB=-10\",\n",
    "        color=dist_to_color[\"eff_res\"],\n",
    "    ),\n",
    ")\n",
    "ax_dict[\"c\"].add_artist(annot2)\n",
    "\n",
    "\n",
    "## panel b bottom: PD for eff res\n",
    "plot_diagrams(\n",
    "    res_eff[seed][\"dgms\"],\n",
    "    ax=ax_dict[\"c\"],\n",
    "    plot_only=[1],\n",
    "    size=4,\n",
    "    color=dist_to_color[\"eff_res\"],\n",
    "    colormap=style_file,  # necessary bc plot_diagrams uses the colormap as mpl style\n",
    ")\n",
    "ax_dict[\"c\"].scatter(*point_eff.T, marker=\"x\", color=dist_to_color[\"eff_res\"], s=5)\n",
    "ax_dict[\"c\"].legend().set_visible(False)\n",
    "\n",
    "ax_dict[\"c\"].set_xticks([])\n",
    "ax_dict[\"c\"].set_yticks([])\n",
    "\n",
    "ax_dict[\"c\"].text(\n",
    "    s=\"Effective\\nresistance\",\n",
    "    x=0.4,\n",
    "    y=0.1,\n",
    "    c=dist_to_color[\"eff_res\"],\n",
    "    transform=ax_dict[\"c\"].transAxes,\n",
    ")\n",
    "\n",
    "\n",
    "# plot outlier scores in panel c (mosaic key \"d\")\n",
    "plot_score_band(ax_dict[\"d\"], outlier_scores_eucl, color=\"k\", label=\"Euclidean\")\n",
    "plot_score_band(ax_dict[\"d\"], outlier_scores_eff, color=dist_to_color[\"eff_res\"], label=\"Effective\\nresistance\")\n",
    "\n",
    "ax_dict[\"d\"].set_ylim(0.0, 1.0)\n",
    "ax_dict[\"d\"].set_xlim(0.0, 0.35)\n",
    "ax_dict[\"d\"].legend(loc=\"lower left\", frameon=False)\n",
    "ax_dict[\"d\"].set_ylabel(\"Loop detection score\")\n",
    "ax_dict[\"d\"].set_xlabel(r\"Noise std $\\sigma$\")\n",
    "ax_dict[\"d\"].set_title(\n",
    "    \"c\\n\",\n",
    "    loc=\"left\",\n",
    "    fontweight=\"bold\",\n",
    "    ha=\"right\",\n",
    ")\n",
    "ax_dict[\"d\"].set_title(\"Loop detection\\n\")\n",
    "\n",
    "# mark the noise level that is shown in the other panels\n",
    "ax_dict[\"d\"].axvline(sigma, linestyle=\"dotted\", c=\"k\")\n",
    "ax_dict[\"d\"].set_xticks([0.0, 0.1, 0.2, 0.25, 0.3])\n",
    "ax_dict[\"d\"].set_xticklabels([0.0, 0.1, 0.2, 0.25, 0.3])\n",
    "\n",
    "\n",
    "# panels d (UMAP, mosaic key \"e\") and e (t-SNE, mosaic key \"f\")\n",
    "for key, letter, title, embd in ((\"e\", \"d\", \"UMAP\", umap_embd), (\"f\", \"e\", \"t-SNE\", tsne_embd)):\n",
    "    plot_scatter(ax=ax_dict[key], x=embd, s=2, y=\"k\", scalebar=False, alpha=1.0)\n",
    "    ax_dict[key].set_title(title)\n",
    "    add_panel_letter(ax_dict[key], letter, x=0.2)\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_1.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "594ddcce85813638"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "d932e31550f5da8b"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
