{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149151aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "import gudhi as gd\n",
    "from sbarlay.sbarlay import *\n",
    "import multipers as mp\n",
    "from multipers.ml.diff import GaussianKDE\n",
    "import matplotlib.pyplot as plt\n",
    "import multipers.ml.multi as mmm\n",
    "from warnings import warn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d428ddc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _RipsCodensity(DX, codensities, max_edge, dimension):\n",
    "    \"\"\"Build a Rips x codensity bifiltration and extract its signed measure.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    DX : pairwise distance matrix, passed to `gd.RipsComplex(distance_matrix=...)`.\n",
    "    codensities : per-vertex values written into filtration parameter 1.\n",
    "    max_edge : Rips threshold (`max_edge_length`).\n",
    "    dimension : homological degree of the signed measure; the complex is expanded to `dimension + 1`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    found_point_indices : list of signed-measure point indices whose edge and vertex were both located.\n",
    "    nodes : int array of shape (k,), the vertex attached to each kept point.\n",
    "    edges : int array of shape (k, 2), the edge (vertex pair) attached to each kept point.\n",
    "    weights_diff : signed-measure weights of the kept points.\n",
    "\n",
    "    Emits a UserWarning when some points could not be located in the simplex tree.\n",
    "    \"\"\"\n",
    "    st = gd.RipsComplex(distance_matrix=DX, max_edge_length=max_edge).create_simplex_tree()\n",
    "    st = mp.SimplexTreeMulti(st, num_parameters=2)\n",
    "    st.fill_lowerstar(codensities, parameter=1)\n",
    "    st.collapse_edges(num=100).expansion(dimension+1)\n",
    "    (sm, weights), = mp.signed_measure(st, degree=dimension, mass_default=None)\n",
    "    # column 0 of `indices` is looked up among edges, column 1 among vertices (-1 = not found)\n",
    "    indices, not_found = mp.signed_measure_indices(st, sm, simplices_dimensions=[1,0])\n",
    "    if len(not_found) != 0:\n",
    "        warn(f\"Some points are ignored as they cannot be found in the simplextree, {not_found.shape}\")\n",
    "    simplices = [s for s,_ in st.get_simplices()]\n",
    "    found_point_indices = [i for i in range(len(weights)) if np.all(indices[i] != -1)]\n",
    "    # explicit dtype and shape keep `edges` usable by tf.gather_nd even when no point is found\n",
    "    edges = np.asarray([simplices[indices[i,0]] for i in found_point_indices], dtype=np.int64).reshape(-1, 2)\n",
    "    nodes = np.asarray([simplices[indices[i,1]][0] for i in found_point_indices], dtype=np.int64)\n",
    "    weights_diff = weights[found_point_indices]\n",
    "    return found_point_indices, nodes, edges, weights_diff\n",
    "\n",
    "class RipsCodensityLayer(tf.keras.layers.Layer):\n",
    "    \"\"\"Differentiable Rips x codensity signed-measure layer.\n",
    "\n",
    "    `call` builds the bifiltration with `_RipsCodensity`, expands the signed measure into\n",
    "    repeated positive and negative points (edge length, codensity), and zero-pads both parts\n",
    "    to `padval` rows together with boolean padding masks.\n",
    "\n",
    "    NOTE(review): the codensities are currently the fixed debugging values in\n",
    "    `_FIXED_CODENSITIES` (one per point, 300 points), re-created as a fresh variable on every\n",
    "    call, instead of an estimate computed from `X`; `theta` and `min_persistence` are unused.\n",
    "    \"\"\"\n",
    "\n",
    "    # One codensity per input point; only valid for 300-point inputs.\n",
    "    _FIXED_CODENSITIES = [\n",
    "        -1.3901, -1.7220, -1.1447, -1.4136, -1.0946, -1.4318, -1.3884, -1.3557,\n",
    "        -0.9871, -1.3174, -1.4331, -1.7350, -0.9161, -1.3497, -1.5657, -0.9497,\n",
    "        -1.3774, -1.1125, -1.6157, -1.4984, -1.6980, -1.1297, -0.9652, -1.4513,\n",
    "        -1.1746, -1.3720, -1.2041, -1.8286, -1.3466, -1.3318, -0.9676, -1.2171,\n",
    "        -1.2352, -1.1926, -1.0565, -1.9565, -1.2316, -1.8219, -1.4177, -1.2469,\n",
    "        -0.9262, -0.9748, -1.2618, -1.3744, -1.4091, -2.0764, -0.9514, -1.7658,\n",
    "        -0.9698, -1.6062, -1.4842, -1.1646, -0.9586, -1.8077, -2.1471, -1.1385,\n",
    "        -1.3899, -1.3996, -1.3422, -1.6799, -1.8039, -1.4895, -1.2648, -1.2062,\n",
    "        -1.2728, -1.6234, -1.3586, -1.3187, -0.9552, -1.7609, -1.3373, -1.3333,\n",
    "        -1.3203, -1.4326, -1.3757, -1.8348, -1.4038, -1.2347, -1.4331, -1.3635,\n",
    "        -1.5232, -1.3373, -1.5991, -1.0574, -1.4120, -1.8132, -1.2772, -1.1879,\n",
    "        -1.1541, -1.1723, -1.4209, -1.5606, -1.0525, -1.4601, -1.3841, -1.3215,\n",
    "        -1.5042, -0.9478, -1.6197, -1.6186, -1.1602, -1.0343, -1.3410, -1.4178,\n",
    "        -0.9534, -1.2695, -1.1301, -1.3179, -1.6982, -1.5269, -1.3124, -1.2743,\n",
    "        -1.2356, -1.2697, -0.9687, -0.9595, -1.0945, -1.1607, -1.5297, -1.7146,\n",
    "        -1.1662, -1.3248, -1.2238, -1.0883, -1.6167, -1.9065, -1.3297, -1.2593,\n",
    "        -1.4276, -1.6785, -1.4536, -1.2307, -1.1671, -1.1493, -2.0133, -2.1172,\n",
    "        -1.2238, -1.5108, -1.5919, -1.3984, -0.8846, -1.1709, -1.0766, -1.2753,\n",
    "        -1.2460, -1.2673, -1.4949, -1.4398, -1.6837, -1.4543, -1.1893, -1.3131,\n",
    "        -1.2449, -1.7273, -1.2407, -0.8973, -1.1475, -1.3489, -1.1957, -1.0751,\n",
    "        -1.2347, -1.1927, -1.3560, -1.4062, -1.2729, -1.2898, -1.3917, -1.4566,\n",
    "        -1.1750, -1.3343, -1.0889, -1.3405, -1.4757, -1.4280, -1.3988, -1.2733,\n",
    "        -1.3892, -1.1645, -1.1565, -1.7365, -1.2371, -1.1546, -1.4067, -1.6533,\n",
    "        -1.3940, -1.7021, -1.2896, -1.3686, -1.2515, -1.3647, -0.9588, -1.4187,\n",
    "        -1.3251, -1.1921, -1.1291, -1.0779, -1.4898, -1.2068, -1.3891, -1.3003,\n",
    "        -1.4790, -1.3070, -1.4130, -1.2259, -0.9488, -1.4820, -1.1851, -0.9258,\n",
    "        -1.3119, -0.9778, -1.1690, -1.4958, -1.8655, -1.4715, -1.3867, -1.2933,\n",
    "        -1.1824, -0.9791, -2.0008, -1.4693, -1.5296, -1.5331, -1.0879, -1.1475,\n",
    "        -1.2261, -1.4083, -1.6689, -1.4281, -1.5135, -1.2113, -1.3529, -1.2367,\n",
    "        -1.0571, -1.7101, -1.6417, -1.3080, -1.0146, -2.2244, -1.4485, -1.2705,\n",
    "        -1.1996, -1.1935, -1.1618, -1.6863, -1.3015, -1.3616, -1.8528, -1.2075,\n",
    "        -1.2488, -1.5426, -1.3641, -1.3955, -1.2461, -1.5052, -1.1103, -1.2446,\n",
    "        -1.4084, -1.1685, -1.4686, -1.8730, -1.1111, -1.3577, -1.4075, -1.1968,\n",
    "        -1.9305, -1.5323, -1.5830, -0.9752, -1.1466, -1.5438, -1.3257, -1.2460,\n",
    "        -1.6074, -0.9470, -1.4130, -1.4489, -1.2338, -1.2790, -1.2687, -1.4552,\n",
    "        -1.3023, -1.4768, -1.6317, -1.3923, -1.7999, -1.5113, -0.9841, -1.2715,\n",
    "        -1.5125, -1.2054, -1.6492, -1.0093, -1.0415, -0.9286, -1.1237, -1.4140,\n",
    "        -1.1698, -1.2378, -1.2206, -1.4478,\n",
    "    ]\n",
    "\n",
    "    def __init__(self, dimension, maximum_edge_length=np.inf, min_persistence=None, padval=1000, **kwargs):\n",
    "        super().__init__(dynamic=True, **kwargs)\n",
    "        self.max_edge = maximum_edge_length\n",
    "        self.dimension = dimension\n",
    "        self.padval = padval\n",
    "        # stored for API compatibility; not read by `call`\n",
    "        self.min_persistence = min_persistence if min_persistence is not None else 0.\n",
    "\n",
    "    @staticmethod\n",
    "    def _pad_with_mask(points, padval):\n",
    "        \"\"\"Zero-pad `points` to `padval` rows and return a (1, padval) mask that is True on padded rows.\"\"\"\n",
    "        num = points.shape[0]\n",
    "        padded = tf.pad(points, np.array([[0, padval - num], [0, 0]]))\n",
    "        mask = tf.constant((np.arange(padval) >= num)[np.newaxis, :], dtype=tf.bool)\n",
    "        return padded, mask\n",
    "\n",
    "    def call(self, X, theta):\n",
    "        \"\"\"Compute the padded signed measure of the point cloud `X`.\n",
    "\n",
    "        X : point coordinates, indexed as (num_points, ambient_dim).\n",
    "        theta : unused while the fixed codensities are in place (only referenced by the\n",
    "            commented-out density estimate below).\n",
    "\n",
    "        Returns (positive points, negative points, positive mask, negative mask). The point\n",
    "        tensors have shape (padval, 2); the masks have shape (1, padval) and are True on padding.\n",
    "\n",
    "        Raises ValueError if `X` does not have one row per fixed codensity, or if either part\n",
    "        of the signed measure has more than `padval` points.\n",
    "        \"\"\"\n",
    "        DX = tf.norm(tf.expand_dims(X, 1)-tf.expand_dims(X, 0), axis=2)\n",
    "        num_points = DX.shape[0]\n",
    "        if num_points != len(self._FIXED_CODENSITIES):\n",
    "            raise ValueError(\n",
    "                f\"RipsCodensityLayer uses {len(self._FIXED_CODENSITIES)} fixed codensities \"\n",
    "                f\"but received {num_points} points\")\n",
    "        self.codens = tf.Variable(self._FIXED_CODENSITIES, dtype=tf.float32, trainable=True)\n",
    "        #self.codens = tf.math.log(tf.math.reduce_sum(tf.math.exp(-tf.math.square(DX)/theta), axis=1))\n",
    "        fp_indices, nodes, edges, w_diff = _RipsCodensity(DX.numpy(), self.codens, self.max_edge, self.dimension)\n",
    "        # first coordinate: edge length, differentiable w.r.t. X through DX\n",
    "        pts_diff_0 = tf.reshape(tf.gather_nd(DX, edges), [-1,1])\n",
    "        # second coordinate: codensity of the attached vertex\n",
    "        pts_diff_1 = tf.reshape(tf.gather(self.codens, nodes), [-1,1])\n",
    "        sm_diff = tf.concat([pts_diff_0, pts_diff_1], axis=1)\n",
    "        # split by weight sign; each point is repeated |weight| times\n",
    "        pos_idxs, neg_idxs = np.argwhere(w_diff > 0).ravel(), np.argwhere(w_diff < 0).ravel()\n",
    "        pos_weights, neg_weights = w_diff[pos_idxs], -w_diff[neg_idxs]\n",
    "        pos_sm_diff = tf.repeat(tf.gather(sm_diff, pos_idxs), pos_weights, axis=0)\n",
    "        neg_sm_diff = tf.repeat(tf.gather(sm_diff, neg_idxs), neg_weights, axis=0)\n",
    "        npos, nneg = pos_sm_diff.shape[0], neg_sm_diff.shape[0]\n",
    "        if max(npos, nneg) > self.padval:\n",
    "            raise ValueError(\n",
    "                f\"padval={self.padval} is too small: got {npos} positive and {nneg} negative points\")\n",
    "        padded_pos_sm_diff, padded_pos_sm_diff_pad = self._pad_with_mask(pos_sm_diff, self.padval)\n",
    "        padded_neg_sm_diff, padded_neg_sm_diff_pad = self._pad_with_mask(neg_sm_diff, self.padval)\n",
    "        return padded_pos_sm_diff, padded_neg_sm_diff, padded_pos_sm_diff_pad, padded_neg_sm_diff_pad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c5a2fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1)\n",
    "Xinit = np.array(np.random.uniform(high=1., low=-1., size=(300,2)), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8833f463",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Trainable point coordinates, initialised from Xinit\n",
    "X = tf.Variable(initial_value=Xinit, trainable=True, dtype=tf.float32)\n",
    "# Non-trainable scalar forwarded to the layer at every call\n",
    "theta = tf.Variable(initial_value=1., trainable=False, dtype=tf.float32)\n",
    "rc_layer = RipsCodensityLayer(\n",
    "    dimension=1,\n",
    "    maximum_edge_length=.5,\n",
    "    padval=1000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00a355a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single SBarlay branch. NOTE(review): the GaussianSBarlayPhi arguments ([50, 50],\n",
    "# [[0,1],[0,1]], [1., 1.]) are presumably grid sizes, ranges and bandwidths -- confirm in sbarlay.\n",
    "params = [\n",
    "    {\n",
    "        \"phi\":     GaussianSBarlayPhi([50] * 2, [[0, 1], [0, 1]], [1.] * 2),\n",
    "        \"weight\":  PowerSBarlayWeight(1.0, 0),\n",
    "        \"perm_op\": tf.math.reduce_sum,\n",
    "        \"rho\":     tf.keras.layers.Identity(),\n",
    "    },\n",
    "]\n",
    "\n",
    "# Final head applied to the concatenated representation\n",
    "rho = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(32),\n",
    "])\n",
    "sbarlay_model = SBarlayModel(params, rho=rho, plot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c58eca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain SGD (no momentum) driven by an inverse-time-decay schedule starting at 0.1\n",
    "lr = tf.keras.optimizers.schedules.InverseTimeDecay(\n",
    "    initial_learning_rate=1e-1,\n",
    "    decay_steps=10,\n",
    "    decay_rate=.01,\n",
    ")\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf32a941",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "plot_every = 1  # draw the diagnostic figure every `plot_every` epochs\n",
    "losses = []\n",
    "for ep in range(num_epochs):\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "        psm_diff, nsm_diff, psm_diff_p, nsm_diff_p = rc_layer(X, theta)\n",
    "        # add a leading batch axis of size 1 to every model input\n",
    "        representation = sbarlay_model([ [\n",
    "            tf.expand_dims(psm_diff, 0),\n",
    "            tf.expand_dims(nsm_diff, 0),\n",
    "            tf.expand_dims(psm_diff_p, 0),\n",
    "            tf.expand_dims(nsm_diff_p, 0),\n",
    "        ] ])\n",
    "        loss = tf.math.reduce_sum(representation)\n",
    "\n",
    "    if ep % plot_every == 0:\n",
    "        fig, (ax_pts, ax_sm) = plt.subplots(nrows=1, ncols=2, figsize=(12,6))\n",
    "        # current point cloud (the one that produced this epoch's loss), not the initial one\n",
    "        points = X.numpy()\n",
    "        ax_pts.scatter(points[:,0], points[:,1])\n",
    "        ax_pts.set_title(f\"Point cloud, epoch {ep}\")\n",
    "        # masks are True on padding rows: drop them so the zero padding is not drawn at the origin\n",
    "        pos_pts = psm_diff.numpy()[~psm_diff_p.numpy()[0]]\n",
    "        neg_pts = nsm_diff.numpy()[~nsm_diff_p.numpy()[0]]\n",
    "        ax_sm.scatter(pos_pts[:,0], pos_pts[:,1], label=\"positive\")\n",
    "        ax_sm.scatter(neg_pts[:,0], neg_pts[:,1], label=\"negative\")\n",
    "        ax_sm.set(title=\"Signed measure\", xlabel=\"edge length\", ylabel=\"codensity\")\n",
    "        ax_sm.legend()\n",
    "        plt.show()\n",
    "        plt.close(fig)  # avoid accumulating open figures across epochs\n",
    "\n",
    "    # only the point coordinates are optimised; the SBarlay model weights stay fixed\n",
    "    gradients = tape.gradient(loss, [X])\n",
    "    optimizer.apply_gradients(zip(gradients, [X]))\n",
    "    losses.append(loss.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8f4354",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training curve: one loss value per epoch\n",
    "fig_loss, ax_loss = plt.subplots()\n",
    "ax_loss.plot(losses)\n",
    "ax_loss.set_xlabel('Epochs')\n",
    "ax_loss.set_ylabel('Loss')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
