{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "from scipy.spatial import distance_matrix\n",
    "import ripserplusplus as rpp_py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('lib/')\n",
    "import ph_simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Graph:\n",
    "    def __init__(self, n_nodes, edges, w):\n",
    "        self.n_nodes = n_nodes\n",
    "        self.edges = edges\n",
    "        self.w = w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_graph(a):\n",
    "    \"\"\"Build a Graph from the entries of matrix `a` that lie below the diagonal.\n",
    "\n",
    "    Every pair (i, j) with j < i < a.shape[0] becomes one edge whose weight is\n",
    "    a[i, j]. Pairs are listed row by row, so edges[k] and w[k] always correspond.\n",
    "    \"\"\"\n",
    "    size = a.shape[0]\n",
    "    pairs = [(row, col) for row in range(size) for col in range(row)]\n",
    "    weights = [a[row, col] for row, col in pairs]\n",
    "    return Graph(size, pairs, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 49.6 ms, sys: 4.18 ms, total: 53.7 ms\n",
      "Wall time: 51.9 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Benchmark setup: n_tasks random point clouds, each stored as a\n",
    "# (distance matrix, Graph) pair in `tasks`.\n",
    "n_tasks = 50\n",
    "N_VERT = 50  # points per cloud\n",
    "DIM_A = 5    # coordinates per point\n",
    "\n",
    "tasks = []\n",
    "for i in range(n_tasks):\n",
    "    # Re-seed with the task index so each cloud is reproducible on its own.\n",
    "    np.random.seed(i)\n",
    "    cloud_a = np.random.random(size=(N_VERT, DIM_A))\n",
    "    a = distance_matrix(cloud_a, cloud_a)\n",
    "    tasks.append((a, to_graph(a)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transform to ph_simple format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every task graph into the batched arrays handed to ph_simple:\n",
    "#   edges    - (2, total_edges) endpoints, shifted by the task's node offset\n",
    "#   w        - (total_edges,) edge weights\n",
    "#   edge_ptr - edge_ptr[k]:edge_ptr[k+1] is the edge range of task k\n",
    "#   node_ptr - node_ptr[k]:node_ptr[k+1] is the node range of task k\n",
    "edge_counts = [len(graph.edges) for _, graph in tasks]\n",
    "node_counts = [dist.shape[0] for dist, _ in tasks]\n",
    "\n",
    "edges_cnt = sum(edge_counts)\n",
    "nodes_cnt = sum(node_counts)\n",
    "\n",
    "# Prefix sums: entry 0 stays 0, the last entry equals the grand total.\n",
    "edge_ptr = np.zeros(len(tasks) + 1, dtype=np.int32)\n",
    "node_ptr = np.zeros(len(tasks) + 1, dtype=np.int32)\n",
    "edge_ptr[1:] = np.cumsum(edge_counts)\n",
    "node_ptr[1:] = np.cumsum(node_counts)\n",
    "\n",
    "edges = np.zeros((2, edges_cnt), dtype=np.int32)\n",
    "w = np.zeros(edges_cnt)\n",
    "\n",
    "for k, (_dist, graph) in enumerate(tasks):\n",
    "    lo, hi = edge_ptr[k], edge_ptr[k + 1]\n",
    "    # Local node ids become global ids by adding the task's node offset.\n",
    "    edges[:, lo:hi] = np.array(graph.edges, dtype=np.int32).T + node_ptr[k]\n",
    "    # Weights pass through float32 before landing in the float64 buffer.\n",
    "    w[lo:hi] = np.array(graph.w, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[   0   50  100  150  200  250  300  350  400  450  500  550  600  650\n",
      "  700  750  800  850  900  950 1000 1050 1100 1150 1200 1250 1300 1350\n",
      " 1400 1450 1500 1550 1600 1650 1700 1750 1800 1850 1900 1950 2000 2050\n",
      " 2100 2150 2200 2250 2300 2350 2400 2450 2500]\n"
     ]
    }
   ],
   "source": [
    "# Node offsets per task; the last entry is the total node count.\n",
    "print(node_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   1,    2,    2, ..., 2499, 2499, 2499],\n",
       "       [   0,    0,    1, ..., 2496, 2497, 2498]], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Batched edge endpoints, already shifted by each task's node offset.\n",
    "edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "61250"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of edges across all tasks.\n",
    "edges_cnt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run ph_simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result buffers, pre-filled with -1; the next cell treats only entries > -1\n",
    "# as results and skips (or maps to None) everything still equal to -1.\n",
    "h0_idx = np.full(nodes_cnt, -1, dtype = np.int32)  # one slot per node\n",
    "h0_e = np.full(edges_cnt, -1, dtype = np.int32)  # one slot per edge\n",
    "h1_e = np.full(edges_cnt, -1, dtype = np.int32)  # one slot per edge\n",
    "\n",
    "# NOTE(review): presumably fills h0_idx / h0_e / h1_e in place; the meaning of the\n",
    "# two trailing `1, 1` arguments is not visible in this notebook - TODO confirm.\n",
    "ph_simple.calc_barcodes_batch_cycles(len(tasks), edges, w, edge_ptr, node_ptr, h0_idx, h0_e, h1_e, 1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _weights_by_graph(graph_of_edge, weights, n_graphs):\n",
    "    \"\"\"Collect weights[k] into bucket graph_of_edge[k] for every k whose entry is > -1.\"\"\"\n",
    "    buckets = [[] for _ in range(n_graphs)]\n",
    "    for weight, graph_num in zip(weights, graph_of_edge):\n",
    "        if graph_num > -1:\n",
    "            buckets[graph_num].append(weight)\n",
    "    return buckets\n",
    "\n",
    "\n",
    "# H0 from h0_idx: an entry > -1 is used as an index into w, -1 becomes None.\n",
    "h0_c = [w[idx] if idx > -1 else None for idx in h0_idx]\n",
    "# Cut the flat list back into one list per task using the node offsets.\n",
    "h0_list = [h0_c[start:stop] for start, stop in zip(node_ptr[:-1], node_ptr[1:])]\n",
    "\n",
    "# H1 from h1_e: an entry > -1 selects the per-task list that receives w[k].\n",
    "h1_list = _weights_by_graph(h1_e, w, n_tasks)\n",
    "\n",
    "# H0 a second way, from h0_e, using the same bucketing as for H1.\n",
    "h0_list_v2 = _weights_by_graph(h0_e, w, n_tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_diff(res0, ripser_h0):\n",
    "    error = False\n",
    "\n",
    "    for v1, v2 in zip(res0, ripser_h0):\n",
    "        if abs(v1 - v2) > 1e-6:\n",
    "            error = True\n",
    "            break\n",
    "            \n",
    "    return not error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare with ripser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n",
      "H0 True True , H1 True\n"
     ]
    }
   ],
   "source": [
    "for i, (dist, _graph) in enumerate(tasks):\n",
    "    # Reference barcodes from ripser++ for the same distance matrix.\n",
    "    barcodes = rpp_py.run(\"--format distance --mode default --dim 1\", dist)\n",
    "    dgms = barcodes['dgms']\n",
    "\n",
    "    # Second field of every dimension-0 bar, first field of every dimension-1 bar.\n",
    "    h0_death_ripser = [bar[1] for bar in dgms[0][:]]\n",
    "    h1_birth_ripser = sorted(bar[0] for bar in dgms[1])\n",
    "\n",
    "    h0_ok = check_diff(h0_death_ripser, h0_list[i])\n",
    "    h0_v2_ok = check_diff(h0_death_ripser, sorted(h0_list_v2[i]))\n",
    "    h1_ok = check_diff(h1_birth_ripser, sorted(h1_list[i]))\n",
    "    print('H0', h0_ok, h0_v2_ok, ', H1', h1_ok)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
