{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "import matplotlib.pyplot as plt\n",
    "from persim import plot_diagrams\n",
    "from utils.utils import compute_ph, get_path\n",
    "from utils.pd_utils import get_persistent_feature_id\n",
    "from utils.fig_utils import plot_edges_on_scatter\n",
    "from utils.dist_utils import get_dist\n",
    "from utils.toydata_utils import get_toy_data\n",
    "from vis_utils.plot import plot_scatter\n",
    "import os\n",
    "from matplotlib.patches import Circle, Polygon\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# matplotlib style sheet used for all figures; the same name is later passed to\n",
    "# plot_diagrams as its `colormap` argument\n",
    "style_file = \"utils.style\"\n",
    "plt.style.use(style_file)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9784cdef92c0071c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# root directory handed to compute_ph below (presumably where PH results are cached -- confirm\n",
    "# in utils.utils) and the folder that receives the saved figures\n",
    "root_path = get_path(\"data\")\n",
    "fig_path = os.path.join(root_path, \"figures\")\n",
    "# make sure the figure folder exists so that fig.savefig does not fail on a fresh checkout\n",
    "os.makedirs(fig_path, exist_ok=True)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "137b07a451d2a5d5"
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Persistent homology (PH) figure"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2d2a415d2f5cde7b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# get data\n",
    "# NOTE: sigma, seed, n and d are also interpolated into the PH file name in the next cell\n",
    "sigma = 0.25  # forwarded as the `sigma` entry of the `gaussian` option\n",
    "seed = 3  # seed forwarded to get_toy_data\n",
    "n = 25  # presumably the number of sampled points -- confirm in get_toy_data\n",
    "d = 2  # presumably the dimension of the points -- confirm in get_toy_data\n",
    "data = get_toy_data(dataset=\"toy_circle\", n=n, d=d, seed=seed, gaussian={\"sigma\":sigma})"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "a5b6641d33d0464d"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# compute / load PH result\n",
    "# the file name encodes the data parameters chosen above\n",
    "dataset = \"toy_circle\"\n",
    "file_name = f\"toy_circle_{n}_d_{d}_ortho_gauss_sigma_{sigma}_seed_{seed}_euclidean\"\n",
    "res = compute_ph(\n",
    "    dist=get_dist(data, distance=\"euclidean\"),\n",
    "    file_name=file_name,\n",
    "    root_dir=root_path,\n",
    "    dataset=dataset,\n",
    "    dim=1,\n",
    "    delete_dists=True,\n",
    "    verbose=True,\n",
    "    force_recompute=False,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "db4711e42a298385"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# get the indices of the most (m=1) and second most (m=2) persistent features;\n",
    "# they are used below to index the dimension-1 diagram res[\"dgms\"][1]\n",
    "ind1 = get_persistent_feature_id(res, m=1, dim=1)\n",
    "ind2 = get_persistent_feature_id(res, m=2, dim=1)\n",
    "print(ind1, ind2)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "43ebfb6836f71a97"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# birth / death times of the most and second most persistent feature\n",
    "birth1, death1 = res[\"dgms\"][1][ind1]\n",
    "birth2, death2 = res[\"dgms\"][1][ind2]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "28abd6528bcae58c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# plot figure: the first three panels show the data with discs, edges and triangles at three\n",
    "# filtration values (a), the last panel shows the persistence diagram (b)\n",
    "plt.rcParams['text.usetex'] = False\n",
    "\n",
    "fig, ax = plt.subplots(ncols=4, figsize=(5.5, 1.5))\n",
    "\n",
    "# filtration values of the three snapshots and the labels printed below them\n",
    "radii = [birth2, birth1, death1]\n",
    "times = [\"$t_1$\", \"$t_2$\", \"$t_3$\"]\n",
    "\n",
    "# alignment shared by every text label in this figure\n",
    "centered = dict(horizontalalignment='center', verticalalignment='center')\n",
    "\n",
    "# loop-invariant helpers, computed once instead of once per panel:\n",
    "# pairwise distances and, for every entry of that matrix, the pair of point indices\n",
    "dist = get_dist(data, distance=\"euclidean\")\n",
    "xx, yy = np.meshgrid(np.arange(len(data)), np.arange(len(data)))\n",
    "coords = np.stack([xx, yy], axis=-1)\n",
    "\n",
    "for i, radius in enumerate(radii):\n",
    "    # plot points\n",
    "    plot_scatter(x=data, ax=ax[i], s=10, y=\"k\", scalebar=False, alpha=1)\n",
    "\n",
    "    # plot discs around points; discs of radius `radius / 2` touch when their centers are `radius` apart\n",
    "    for pt in data:\n",
    "        disc = Circle(pt, radius=radius / 2, facecolor='k', alpha=0.1, clip_on=False, edgecolor=\"none\")\n",
    "        ax[i].add_patch(disc)\n",
    "\n",
    "    # existing edges (strictly shorter than radius) vs. edges appearing exactly at this radius\n",
    "    mask = dist < radius\n",
    "    mask_new = np.isclose(dist, radius)\n",
    "\n",
    "    mask[mask_new] = False\n",
    "    # keep every unordered pair once and drop the diagonal\n",
    "    mask = np.triu(mask, k=1)\n",
    "    mask_new = np.triu(mask_new, k=1)\n",
    "\n",
    "    edges = coords[mask]\n",
    "\n",
    "    # existing edges are solid, newly appearing edges are dotted\n",
    "    plot_edges_on_scatter(ax=ax[i], edge_idx=edges, x=data, color=\"k\",linewidth=0.5)\n",
    "    plot_edges_on_scatter(ax=ax[i], edge_idx=coords[mask_new], x=data, color=\"k\",linewidth=0.5, linestyle=\"dotted\")\n",
    "\n",
    "    # plot triangles: an existing edge plus a third point whose pairwise distances are all below the radius\n",
    "    for edge in edges:\n",
    "        for pt_id in range(len(data)):\n",
    "            if pt_id == edge[0] or pt_id == edge[1]:\n",
    "                continue\n",
    "            pt = data[pt_id]\n",
    "            triangle = np.array([pt, data[edge[0]], data[edge[1]]])\n",
    "            # small relative slack so that an edge of length (numerically) equal to radius still counts\n",
    "            if np.all((get_dist(triangle, distance=\"euclidean\") < radius * 1.0001)):\n",
    "                # Create a polygon representing the triangle\n",
    "                triangle = Polygon(triangle, closed=True, edgecolor='none', facecolor='lightblue', alpha=1.0)\n",
    "\n",
    "                # Add the triangle to the axis\n",
    "                ax[i].add_patch(triangle)\n",
    "\n",
    "    # filtration times, placed below the panel in axes coordinates\n",
    "    ax[i].text(0.5, -0.2, times[i], transform=ax[i].transAxes, **centered)\n",
    "\n",
    "# plot persistence diagram (dimension 1 only) with ticks at the highlighted birth / death times\n",
    "plot_diagrams(res[\"dgms\"], ax=ax[-1], plot_only=[1], color=\"k\", colormap=style_file)\n",
    "ax[-1].set_yticks([death1])\n",
    "ax[-1].set_yticklabels([\"$t_3$\"])\n",
    "ax[-1].set_ylabel(\"Death\", labelpad=-5)\n",
    "ax[-1].set_xticks([birth2, birth1])\n",
    "ax[-1].set_xticklabels([\"$t_1$\", \"$t_2$\"])\n",
    "ax[-1].legend().set_visible(False)\n",
    "\n",
    "# plot dotted lines in persistence diagram\n",
    "ax[-1].plot([0, birth1], [death1, death1], zorder=-1, c=\"k\", linestyle=\"dotted\")\n",
    "ax[-1].plot([birth1, birth1], [0, death1], zorder=-1, c=\"k\", linestyle=\"dotted\")\n",
    "ax[-1].plot([birth2, birth2], [0, death2], zorder=-1, c=\"k\", linestyle=\"dotted\")\n",
    "\n",
    "# segments from each of the two points to the diagonal: green for the most persistent\n",
    "# feature (p_1), red for the second most persistent one (p_2)\n",
    "mid = (birth1 + death1) / 2\n",
    "ax[-1].plot([birth1, mid], [death1, mid], zorder=-1, c=\"g\")\n",
    "\n",
    "mid2 = (birth2 + death2) / 2\n",
    "ax[-1].plot([birth2, mid2], [death2, mid2], zorder=-1, c=\"r\")\n",
    "\n",
    "# label the two segments; positions are hand-tuned data coordinates for this data set\n",
    "ax[-1].text(0.85, 1.05, \"$p_1$\", color=\"g\", **centered)\n",
    "ax[-1].text(0.66, 0.75, \"$p_2$\", color=\"r\", **centered)\n",
    "\n",
    "# plot detection score formula, Score = (p_1 - p_2) / p_1, assembled from separate text pieces\n",
    "left = 0.115\n",
    "ax[-1].text(left +0.75, 0.5, 'Score =', **centered)\n",
    "\n",
    "left = left + 0.025\n",
    "# numerator\n",
    "ax[-1].text(left +0.92, 0.55, '$p_1$', color=\"g\", **centered)\n",
    "ax[-1].text(left +0.98, 0.55, '$-$', **centered)\n",
    "ax[-1].text(left +1.05, 0.55, '$p_2$', color=\"r\", **centered)\n",
    "\n",
    "# fraction bar\n",
    "ax[-1].text(left +0.98, 0.535, '______', color=\"k\", **centered)\n",
    "\n",
    "# denominator\n",
    "ax[-1].text(left +0.98, 0.46, '$p_1$', color=\"g\", **centered)\n",
    "\n",
    "# panel letters\n",
    "ax[0].set_title(\"a\", fontweight=\"bold\", loc=\"left\", )\n",
    "ax[-1].set_title(\"b\", fontweight=\"bold\", loc=\"left\", ha=\"right\")\n",
    "\n",
    "\n",
    "fig.savefig(os.path.join(fig_path, \"fig_ph.pdf\"))"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "25a68e24fa7d9890"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "959e3507b7b9c7da"
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "conda-env-ph-py",
   "language": "python",
   "display_name": "Python [conda env:ph]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
