{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Group manifold graphs\n",
    "\n",
    "\n",
    "In this tutorial, we introduce the notion of a group manifold graph, a discretization of a Riemannian manifold. At the moment, four manifolds are available: the translation group $\\mathbb{R}^2$, the roto-translation group $SE(2)$, the 3D rotation group $SO(3)$ and the 2-sphere $S^2$.\n",
    "\n",
    "We define such a graph as follows:\n",
    "- the vertices correspond to **uniformly sampled** elements of the manifold,\n",
    "- the edges connect each vertex to its **K nearest neighbors**, w.r.t. an **anisotropic Riemannian distance**,\n",
    "- the edge weights are computed by a **Gaussian weight kernel** applied to the Riemannian distance between vertices."
   ]
  },
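  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the construction described above (the notation is assumed here, not taken from the library): write $v_k$ for the coordinates of the displacement from vertex $g_i$ to vertex $g_j$ in a chosen basis of the tangent space, $\\sigma_k > 0$ for per-direction anisotropy weights (plausibly what the `sigmas` argument used below controls), and $t > 0$ for a kernel bandwidth. An anisotropic squared distance and a Gaussian edge weight could then take the form\n",
    "\n",
    "$$d^2(g_i, g_j) = \\sum_k \\sigma_k \\, v_k^2, \\qquad w_{ij} = \\exp\\left(-\\frac{d^2(g_i, g_j)}{2 t^2}\\right).$$\n",
    "\n",
    "With this form, a large $\\sigma_k$ makes moves along direction $k$ expensive, so the $K$ nearest neighbors of a vertex concentrate along the cheap directions. The exact normalization used by the library may differ."
   ]
  },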
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm"
   ]
  },
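  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three ingredients listed in the introduction can be illustrated on a flat toy example. The cell below is a minimal sketch, not the library's implementation: the helper `knn_gaussian_graph`, its `bandwidth` parameter and the kernel normalization are assumptions made for illustration, and the plain Euclidean difference stands in for the Riemannian logarithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def knn_gaussian_graph(points, K, sigmas, bandwidth=1.0):\n",
    "    # hypothetical helper: points is an (N, D) tensor, sigmas holds D anisotropy weights\n",
    "    diff = points[:, None, :] - points[None, :, :]  # (N, N, D) pairwise differences\n",
    "    sqdist = (diff.pow(2) * torch.tensor(sigmas)).sum(-1)  # anisotropic squared distance\n",
    "    sqdist.fill_diagonal_(float('inf'))  # exclude self-loops from the neighbor search\n",
    "    knn_sqdist, knn_index = sqdist.topk(K, dim=1, largest=False)  # K smallest per row\n",
    "    source = torch.arange(points.shape[0]).repeat_interleave(K)\n",
    "    edge_index = torch.stack((source, knn_index.reshape(-1)))  # row 0: source, row 1: target\n",
    "    edge_sqdist = knn_sqdist.reshape(-1)\n",
    "    edge_weight = torch.exp(-edge_sqdist / (2 * bandwidth ** 2))  # Gaussian weight kernel\n",
    "    return edge_index, edge_sqdist, edge_weight\n",
    "\n",
    "# toy example: a 5 x 5 regular grid in the plane, 4 neighbors per vertex\n",
    "grid = torch.cartesian_prod(torch.arange(5.), torch.arange(5.))\n",
    "edge_index, edge_sqdist, edge_weight = knn_gaussian_graph(grid, K=4, sigmas=(1., 1.))\n",
    "edge_index.shape, edge_weight.shape"
   ]
  },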
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a group manifold graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, S2GEGraph, R2GEGraph, RandomSubGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_graph = R2GEGraph(\n",
    "    size=[28, 28, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps = 0.1\n",
    "xi = 2.048 / (28 ** 2)\n",
    "\n",
    "se2_graph = SE2GEGraph(\n",
    "    size=[28, 28, 6],\n",
    "    K=16,\n",
    "    sigmas=(1., 1/eps**2, xi),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph = S2GEGraph(\n",
    "    size=[642, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "so3_graph = SO3GEGraph(\n",
    "    size=[642, 6],\n",
    "    K=32,\n",
    "    sigmas=(1., .1, 10/642),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.is_connected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.is_undirected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.manifold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.num_vertices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.num_edges # number of directed edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.vertex_index[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.vertex_attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.vertex_beta[:10], s2_graph.vertex_gamma[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.edge_index[:, :10] # first 10 edges, assuming shape (2, num_edges): row 0 is source, row 1 is target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.edge_weight[:10] # weights of the first 10 edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.edge_sqdist[:10] # squared Riemannian distances of the first 10 edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph.neighborhood(9) # neighbors' indices, edge weights and squared Riemannian distances"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Static visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_graph(graph, size):\n",
    "    # size = (M, L): the vertices are plotted as L slices of M vertices, one 3D subplot per slice\n",
    "    M, L = size\n",
    "\n",
    "    fig = plt.figure(figsize=(5*L, 5))\n",
    "    \n",
    "    X, Y, Z = graph.cartesian_pos()\n",
    "\n",
    "    for l in range(L):\n",
    "        ax = fig.add_subplot(1, L, l + 1, projection=\"3d\")\n",
    "        ax.scatter(X[l*M:(l+1)*M], Y[l*M:(l+1)*M], Z[l*M:(l+1)*M], c=\"firebrick\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "def plot_graph_neighborhood(graph, index, size):\n",
    "    # color each vertex by the weight of its edge from vertex `index` (0 for non-neighbors)\n",
    "    M, L = size\n",
    "\n",
    "    fig = plt.figure(figsize=(5, 5))\n",
    "    \n",
    "    X, Y, Z = graph.cartesian_pos()\n",
    "\n",
    "    neighbors_indices, neighbors_weights, _ = graph.neighborhood(index)\n",
    "    weights = torch.zeros(graph.num_vertices)\n",
    "    weights[neighbors_indices] = neighbors_weights\n",
    "    for l in range(L):\n",
    "        ax = fig.add_subplot(L, 1, l + 1, projection=\"3d\")\n",
    "        ax.scatter(X[l*M:(l+1)*M], Y[l*M:(l+1)*M], Z[l*M:(l+1)*M], c=weights[l*M:(l+1)*M], cmap=cm.PuRd)\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_graph(s2_graph, [642, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_graph_neighborhood(s2_graph, 406, [642, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dynamic visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gechebnet.graphs.viz import visualize_graph, visualize_graph_neighborhood, visualize_graph_signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eps = 0.1\n",
    "xi = 6 / (28 ** 2)\n",
    "\n",
    "se2_graph = SE2GEGraph(\n",
    "    size=[28, 28, 6],\n",
    "    K=32,\n",
    "    sigmas=(1., 1/eps, xi),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_graph_neighborhood(se2_graph, 156)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "so3_graph = SO3GEGraph(\n",
    "    size=[642, 6],\n",
    "    K=16,\n",
    "    sigmas=(1., .1, 10/642),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_graph(so3_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "signal = torch.rand(s2_graph.num_vertices)\n",
    "visualize_graph_signal(s2_graph, signal)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random subgraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_subgraph = RandomSubGraph(s2_graph)\n",
    "random_subgraph.num_vertices, random_subgraph.num_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_subgraph.reinit()\n",
    "random_subgraph.edge_sampling(0.9)\n",
    "random_subgraph.num_vertices, random_subgraph.num_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_subgraph.reinit()\n",
    "random_subgraph.node_sampling(0.5)\n",
    "random_subgraph.num_vertices, random_subgraph.num_edges"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
