{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural network layers\n",
    "\n",
    "In this tutorial, we introduce the three kinds of neural network layers used in this thesis. The Chebyshev convolutional layer is a spectral method and has a diffusion effect on the original signal. The pooling and unpooling layers are used to change the resolution of an image: pooling down-samples and reduces it, while unpooling up-samples and expands it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, R2GEGraph, S2GEGraph\n",
    "from gechebnet.nn.layers.convs import ChebConv\n",
    "from gechebnet.nn.layers.pools import SE2SpatialPool, SO3SpatialPool\n",
    "from gechebnet.nn.layers.unpools import SE2SpatialUnpool, SO3SpatialUnpool\n",
    "from gechebnet.utils.utils import delta_kronecker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_signal(graph, signal, size):\n",
    "    \"\"\"Plot a graph signal as one 3D scatter panel per layer.\n",
    "\n",
    "    Args:\n",
    "        graph: object whose ``cartesian_pos()`` returns the X, Y, Z node coordinates.\n",
    "        signal: tensor or array whose last axis indexes the ``M * L`` graph nodes;\n",
    "            each layer's slice must flatten to ``M`` values (e.g. shape ``(1, 1, M * L)``).\n",
    "        size: ``(M, L)``, the number of nodes per layer and the number of layers.\n",
    "            Layer ``l`` is made of nodes ``l*M`` to ``(l+1)*M - 1``.\n",
    "    \"\"\"\n",
    "    # Registers the \"3d\" projection on matplotlib < 3.2 (no-op on newer versions).\n",
    "    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401\n",
    "\n",
    "    M, L = size\n",
    "    X, Y, Z = graph.cartesian_pos()\n",
    "    # Detached host copy, so GPU tensors or tensors requiring grad can be plotted too.\n",
    "    values = torch.as_tensor(signal).detach().cpu().numpy()\n",
    "\n",
    "    fig = plt.figure(figsize=(3 * L, 3))\n",
    "    for l in range(L):\n",
    "        nodes = slice(l * M, (l + 1) * M)\n",
    "        # A 3D axes is required: on a 2D axes the third positional argument of\n",
    "        # `scatter` is the marker size, so Z would be drawn as point sizes.\n",
    "        ax = fig.add_subplot(1, L, l + 1, projection=\"3d\")\n",
    "        ax.scatter(\n",
    "            X[nodes], Y[nodes], Z[nodes],\n",
    "            c=values[..., nodes].reshape(-1),\n",
    "            cmap=cm.PiYG,\n",
    "        )\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SE(2) Group Manifold Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convolutional layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph dimensions, reused below so the impulse and plot sizes stay in sync with the graph.\n",
    "se2_size = [28, 28, 6]\n",
    "\n",
    "se2_graph = SE2GEGraph(\n",
    "    se2_size,\n",
    "    K=16,\n",
    "    sigmas=(1., 0.1, 0.0026),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 1\n",
    "out_channels = 1\n",
    "kernel_size = 4\n",
    "conv = ChebConv(in_channels, out_channels, kernel_size, se2_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nodes per layer (28 * 28 = 784) and number of layers (6), as expected by `plot_signal`.\n",
    "se2_layer_nodes = se2_size[0] * se2_size[1]\n",
    "se2_layers = se2_size[2]\n",
    "# 406 = 14 * 28 + 14, presumably the centre (14, 14) of the first 28 x 28 layer.\n",
    "se2_impulse_node = 406\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Unit impulse; the (1, 1, nodes) shape is presumably (batch, channel, node), in_channels is 1.\n",
    "    impulse = delta_kronecker((1, 1, se2_layer_nodes * se2_layers), (0, 0, se2_impulse_node))\n",
    "    output = conv(impulse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_signal(se2_graph, output, (se2_layer_nodes, se2_layers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pooling and unpooling layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Planar graphs at three resolutions: the 20 x 20 input grid and the 10 x 10 / 40 x 40\n",
    "# grids used to plot the pooled / unpooled signals. Distinct names keep `se2_graph`\n",
    "# bound to the SE(2) graph built above, so the convolution cells can still be re-run.\n",
    "up_r2_graph = R2GEGraph(\n",
    "    [40, 40, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")\n",
    "\n",
    "r2_graph = R2GEGraph(\n",
    "    [20, 20, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")\n",
    "\n",
    "down_r2_graph = R2GEGraph(\n",
    "    [10, 10, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the RNGs so the random input is reproducible under Restart & Run All; numpy is\n",
    "# seeded too in case the \"rand\" pooling mode draws from it (TODO confirm).\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "r2_signal = torch.rand(20*20*1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factor 2 on a (20, 20, 1) input: presumably halves / doubles each spatial side, matching\n",
    "# the 10 x 10 and 40 x 40 plotting graphs. The third argument selects the pooling mode.\n",
    "se2_pool = SE2SpatialPool(2, (20, 20, 1), \"rand\")\n",
    "se2_unpool = SE2SpatialUnpool(2, (20, 20, 1), \"rand\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_signal(r2_graph, r2_signal, (20*20, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    plot_signal(down_r2_graph, se2_pool(r2_signal), (10*10, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    plot_signal(up_r2_graph, se2_unpool(r2_signal), (40*40, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SO(3) Group Manifold Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convolutional layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph dimensions, reused below so the impulse and plot sizes stay in sync with the graph.\n",
    "so3_size = [642, 6]\n",
    "\n",
    "so3_graph = SO3GEGraph(\n",
    "    size=so3_size,\n",
    "    K=32,\n",
    "    sigmas=(1., .1, 10.0 / so3_size[0]),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 1\n",
    "out_channels = 1\n",
    "kernel_size = 4\n",
    "conv = ChebConv(in_channels, out_channels, kernel_size, so3_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nodes per layer (642) and number of layers (6), as expected by `plot_signal`.\n",
    "so3_layer_nodes, so3_layers = so3_size\n",
    "so3_impulse_node = 143  # a node of the first layer (indices 0..641)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Unit impulse; the (1, 1, nodes) shape is presumably (batch, channel, node), in_channels is 1.\n",
    "    impulse = delta_kronecker((1, 1, so3_layer_nodes * so3_layers), (0, 0, so3_impulse_node))\n",
    "    output = conv(impulse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_signal(so3_graph, output, (so3_layer_nodes, so3_layers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pooling and unpooling layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spherical graphs at three resolutions: the 642-node input sphere and the 162- / 2562-node\n",
    "# spheres used to plot the pooled / unpooled signals. Distinct names keep `so3_graph`\n",
    "# bound to the SO(3) graph built above, so the convolution cells can still be re-run.\n",
    "up_s2_graph = S2GEGraph(\n",
    "    size=[2562, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")\n",
    "\n",
    "s2_graph = S2GEGraph(\n",
    "    [642, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")\n",
    "\n",
    "down_s2_graph = S2GEGraph(\n",
    "    [162, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the RNG so the random input is reproducible under Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "s2_signal = torch.rand(642*1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factor 2 on a (642, 1) input: presumably one refinement level down / up, matching the\n",
    "# 162- and 2562-node plotting graphs. The third argument selects the pooling mode.\n",
    "so3_pool = SO3SpatialPool(2, (642, 1), \"max\")\n",
    "so3_unpool = SO3SpatialUnpool(2, (642, 1), \"avg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_signal(s2_graph, s2_signal, (642, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    plot_signal(down_s2_graph, so3_pool(s2_signal), (162, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    plot_signal(up_s2_graph, so3_unpool(s2_signal), (2562, 1))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
