{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Eigenspace\n",
    "\n",
    "\n",
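    "In this tutorial, we introduce the eigenspace of a group manifold graph. It is obtained from the eigendecomposition of the graph's symmetric normalized Laplacian. The eigenvalues can be interpreted as graph frequencies: eigenvectors associated with small eigenvalues vary smoothly over the graph, while those associated with large eigenvalues oscillate rapidly."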
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from gechebnet.graphs.graphs import SE2GEGraph, SO3GEGraph, S2GEGraph, R2GEGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_eigenvalues(graph, indices):\n",
    "    \"\"\"Scatter-plot selected eigenvalues of the graph Laplacian.\n",
    "\n",
    "    Args:\n",
    "        graph: graph exposing ``get_eigen_space()``, which returns\n",
    "            ``(eigenvalues, eigenvectors)``.\n",
    "        indices: positions k of the eigenvalues to plot; they are also\n",
    "            used as the x-coordinates of the points.\n",
    "    \"\"\"\n",
    "    spectrum, _eigvecs = graph.get_eigen_space()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.scatter(indices, spectrum[indices], c=\"firebrick\")\n",
    "    ax.set_xlabel(r\"$k$\")\n",
    "    ax.set_ylabel(r\"$\\lambda_k$\")\n",
    "    fig.tight_layout()\n",
    "\n",
    "def plot_eigenspace(graph, indices, size):\n",
    "    \"\"\"Plot selected eigenvectors as node colors, one subplot row per layer.\n",
    "\n",
    "    The nodes are split into ``L`` consecutive layers of ``M`` nodes; subplot\n",
    "    (l, k) scatters layer ``l`` at its X/Y coordinates from\n",
    "    ``graph.cartesian_pos()`` and colors it by eigenvector ``indices[k]``.\n",
    "\n",
    "    Args:\n",
    "        graph: graph exposing ``get_eigen_space()`` -- returning\n",
    "            ``(eigenvalues, eigenvectors)`` with eigenvectors as a NumPy array\n",
    "            indexed ``[node, eigenvector]`` -- and ``cartesian_pos()``.\n",
    "        indices: eigenvector indices to plot, one subplot column each.\n",
    "        size: ``(M, L)``, the number of nodes per layer and of layers.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``M * L`` exceeds the number of nodes (the extra\n",
    "            layers would otherwise be silently drawn empty).\n",
    "    \"\"\"\n",
    "    M, L = size\n",
    "    K = len(indices)\n",
    "\n",
    "    _, eigenvec = graph.get_eigen_space()\n",
    "    # torch.from_numpy shares memory with its input: clone so that zeroing\n",
    "    # below cannot write through to an array the graph may have cached.\n",
    "    eigenvec = torch.from_numpy(eigenvec).clone()\n",
    "    # Eigenvector 0 is blanked out (set to 0) before plotting.\n",
    "    eigenvec[:, 0] = 0.\n",
    "\n",
    "    num_nodes = eigenvec.shape[0]\n",
    "    if M * L > num_nodes:\n",
    "        raise ValueError(\n",
    "            f\"size={size!r} needs {M * L} nodes but the graph has {num_nodes}\"\n",
    "        )\n",
    "\n",
    "    X, Y, _ = graph.cartesian_pos()\n",
    "\n",
    "    fig = plt.figure(figsize=(3*K, 3*L))\n",
    "    for l in range(L):\n",
    "        layer = slice(l * M, (l + 1) * M)  # rows of the nodes in layer l\n",
    "        for k in range(K):\n",
    "            ax = fig.add_subplot(L, K, l * K + k + 1)\n",
    "            ax.scatter(X[layer], Y[layer], c=eigenvec[layer, indices[k]], cmap=cm.RdBu)\n",
    "            ax.axis(\"off\")\n",
    "\n",
    "    fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the eigenspace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S2GEGraph instance used below to demonstrate get_eigen_space()\n",
    "s2_graph = S2GEGraph(size=[642, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph=\"saved_graphs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues and eigenvectors of the symmetric normalized graph Laplacian\n",
    "(eigen_values, eigen_vectors) = s2_graph.get_eigen_space()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spectral range: for a symmetric normalized Laplacian it should lie in [0, 2] (up to rounding)\n",
    "(eigen_values.min(), eigen_values.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the eigenspace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Translation group $\\mathbb{R}^2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2_graph = R2GEGraph([28, 28, 1], K=8, sigmas=(1.0, 1.0, 1.0), path_to_graph=\"saved_graphs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues lambda_k for k = 0..24\n",
    "plot_eigenvalues(r2_graph, indices=np.arange(25))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 28 x 28 = 784 nodes in a single layer\n",
    "plot_eigenspace(r2_graph, torch.arange(10), size=(28 * 28, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Roto-translation group $SE(2)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "se2_graph = SE2GEGraph(\n",
    "    [28, 28, 6],\n",
    "    K=16,\n",
    "    # NOTE(review): the rationale for these sigma values is not documented here\n",
    "    sigmas=(1.0, 0.1, 0.0026),\n",
    "    path_to_graph=\"saved_graphs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues lambda_k for k = 0..24\n",
    "plot_eigenvalues(se2_graph, indices=np.arange(25))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 28 x 28 = 784 nodes per layer, 6 layers (last entry of the SE2 graph size)\n",
    "plot_eigenspace(se2_graph, torch.arange(10), size=(28 * 28, 6))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2-sphere $S^2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2_graph = S2GEGraph(\n",
    "    size=[642, 1],\n",
    "    K=8,\n",
    "    sigmas=(1., 1., 1.),  # tuple, as in the other graph constructions\n",
    "    path_to_graph=\"saved_graphs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues lambda_k for k = 0..49\n",
    "plot_eigenvalues(s2_graph, indices=np.arange(50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 642 nodes in a single layer.\n",
    "# NOTE(review): plot_eigenspace draws only the X/Y coordinates, so points on the sphere are\n",
    "# presumably projected onto the xy-plane, with the two hemispheres overlapping.\n",
    "plot_eigenspace(s2_graph, torch.arange(10), size=(642, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3D rotation group $SO(3)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "so3_graph = SO3GEGraph(\n",
    "    size=[642, 6],\n",
    "    K=16,\n",
    "    # the third sigma is written as 10.0 / 642, where 642 matches size[0]\n",
    "    sigmas=(1.0, 0.1, 10.0 / 642),\n",
    "    path_to_graph=\"saved_graphs\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenvalues lambda_k for k = 0..19\n",
    "plot_eigenvalues(so3_graph, indices=np.arange(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 642 nodes per layer, 6 layers (matching size=[642, 6] of the SO3 graph)\n",
    "plot_eigenspace(so3_graph, torch.arange(10), size=(642, 6))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
