{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, S2GEGraph, R2GEGraph, RandomSubGraph\n",
    "from gechebnet.graphs.viz import visualize_graph_signal\n",
    "from gechebnet.utils.utils import delta_kronecker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "se2_graph = SE2GEGraph(\n",
    "    size=[25, 25, 6],\n",
    "    K=16,\n",
    "    sigmas= (1., 0.1, 0.0026), #(1., 1., 2.048 / (28 ** 2)), #0.1, 2.048 / (28 ** 2)),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_graph = R2GEGraph(\n",
    "    size=[25, 25, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_graph_diffusion(graph, tau, tol=1e-4):\n",
    "    input = delta_kronecker((graph.num_vertices), graph.centroid_vertex)\n",
    "    \n",
    "    L, M = graph.dim\n",
    "    X, Y, Z = graph.cartesian_pos()\n",
    "    \n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    \n",
    "    graph_kernel = graph.diff_kernel(lambda x: np.exp(-tau*x))\n",
    "    signal = (graph_kernel @ input)\n",
    "    mask = signal > tol\n",
    "\n",
    "    ax = fig.add_subplot(projection='3d')\n",
    "    ax.scatter(X, Y, Z, color=\"none\", edgecolor=\"black\", s=50, alpha=0.2)\n",
    "    ax.scatter(X[mask], Y[mask], Z[mask], c=signal[mask], s=50, cmap=\"cividis\")\n",
    "    ax.set_xlabel(fr\"$x$\")\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_zticks([])\n",
    "    #fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_graph_diffusion(r2_graph, 10.)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
