{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2e3f52f",
   "metadata": {},
   "source": [
    "# Learning Deep O($n$)-Equivariant Hyperspheres\n",
    "\n",
    "This notebook is meant to be used to verify the theoretical results from the paper, such as the equivariance of the proposed $n \\text{D}$ neurons.\n",
    "\n",
    "The complete implementation of our method, as well as experiment-related scripts, will be made publicly available upon acceptance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c40c580",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.set_printoptions(precision=4, sci_mode=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b05bfced",
   "metadata": {},
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cbaa60fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extend_diagonal(R, n):\n",
    "    '''\n",
    "    Embed a square matrix into the top-left corner of an nxn identity matrix.\n",
    "\n",
    "    Args:\n",
    "        R - torch.Tensor - an mxm matrix with m <= n\n",
    "        n - int - size of the returned matrix\n",
    "\n",
    "    Return:\n",
    "        nxn matrix whose top-left mxm block is R and whose remaining diagonal\n",
    "        entries are ones; it is created with the dtype and on the device of R\n",
    "    '''\n",
    "    assert R.dim() == 2 and R.shape[0] == R.shape[1], 'R must be a square matrix'\n",
    "    m = R.shape[0]\n",
    "    assert m <= n\n",
    "    # create the identity with R's dtype/device so that float64 or GPU inputs\n",
    "    # are neither silently down-cast to float32 nor moved to the CPU\n",
    "    extended = torch.eye(n, dtype=R.dtype, device=R.device)\n",
    "    extended[:m, :m] = R\n",
    "    return extended\n",
    "\n",
    "\n",
    "def compute_rotation_from_two_points(p, q):\n",
    "    ''' \n",
    "    Compute, for every pair of rows (p_b, q_b), a rotation R_b with R_b p_b = q_b.\n",
    "\n",
    "    A reflection method:\n",
    "        assuming ||p|| == ||q||\n",
    "        f(A, u) = A - 2 u (u^T A)/||u||^2\n",
    "        S = f(I_D, p+q)\n",
    "        R = f(S, q)\n",
    "\n",
    "    Args: \n",
    "        p, q - torch.Tensor of shape (B, D) - two batches of D-dimensional points,\n",
    "               necessarily with ||p_b|| == ||q_b|| and p_b != -q_b\n",
    "\n",
    "    Return:\n",
    "        R - torch.Tensor of shape (B, D, D) - rotation matrices such that\n",
    "            R[b] @ p[b] = q[b]; R has the dtype and the device of p\n",
    "    '''    \n",
    "    assert len(p.shape) == 2 and p.shape == q.shape\n",
    "    a = torch.abs(p.norm(dim=1, keepdim=True).pow(2) - q.norm(dim=1, keepdim=True).pow(2)).max()\n",
    "\n",
    "    assert a < 1e-5, 'Such a rotation doesn\\'t exist: ||p|| must be equal to ||q||, '+str(a)\n",
    "    # the first mirror is orthogonal to p + q, which is undefined when p == -q\n",
    "    # (this includes p == q == 0): fail loudly instead of returning NaNs\n",
    "    assert (p + q).norm(dim=1).min() > 0, 'The reflection method is undefined for p == -q'\n",
    "    B, D = p.shape \n",
    "\n",
    "    def reflection(S, u):   \n",
    "        # left-multiply S with the reflection across the hyperplane orthogonal to u:\n",
    "        # S is a batch of matrices (BxDxD); S and u must have the same batch size.\n",
    "        assert len(S) == len(u) and S.shape[-1] == u.shape[-1]\n",
    "\n",
    "        v = torch.matmul(u.unsqueeze(1), S) # (Bx1xD)\n",
    "        v = v.squeeze(1) / u.norm(dim=1, keepdim=True)**2 # (BxD) / (Bx1) = (BxD)\n",
    "        M = S - 2 * torch.matmul(u.unsqueeze(-1), v.unsqueeze(1)) # the matmul performs the outer product of u and v            \n",
    "        return M\n",
    "\n",
    "    # build the identity in p's dtype and on p's device; a float32 identity\n",
    "    # would make the matmul fail for float64 inputs\n",
    "    identity = torch.eye(D, dtype=p.dtype, device=p.device).repeat(B, 1, 1)\n",
    "    S = reflection(identity, p+q)  # S @ p = -q, S @ q = -p\n",
    "    R = reflection(S, q) # R @ p = q, R.T @ q = p            \n",
    "\n",
    "    return R\n",
    "\n",
    "\n",
    "\n",
    "def random_orthogonal_matrix(n):\n",
    "    '''\n",
    "    Draw a random nxn orthogonal matrix.\n",
    "\n",
    "    The entries of an nxn matrix are sampled uniformly from [0, 1) and the\n",
    "    Q factor of its QR decomposition is returned.\n",
    "    '''\n",
    "    uniform_entries = torch.rand(n, n)\n",
    "    # keep only the orthogonal factor; the triangular factor is discarded\n",
    "    return torch.linalg.qr(uniform_entries)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7131f93",
   "metadata": {},
   "source": [
    "### Main class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "760f23ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EquivariantHyperspheres:\n",
    "    '''\n",
    "    Ingredients for the numerical equivariance check of an nD spherical neuron.\n",
    "\n",
    "    Attributes created by __init__ (default dtype, CPU):\n",
    "        n         - the dimensionality of the input space\n",
    "        kappa, mu - the scalars used to place the simplex vertices\n",
    "        p         - n x (n+1) tensor; its columns are the vertices of the regular n-simplex\n",
    "        M_n       - (n+1) x (n+1) tensor; its columns are the vertices extended\n",
    "                    with n^(-1/2) and scaled to unit norm\n",
    "        p_norm    - 0-dim tensor; the norm of an extended vertex\n",
    "        R_T       - (n+1) x 1 x n x n tensor; R_T[i, 0] is the rotation returned by\n",
    "                    compute_rotation_from_two_points from the first vertex to the\n",
    "                    i-th one (R_T[0, 0] is the identity)\n",
    "\n",
    "    Attributes created by embed():\n",
    "        X, S      - the embedded data vector and the embedded sphere, both in R^(n+2)\n",
    "    '''\n",
    "\n",
    "    def __init__(self, n):\n",
    "        if n < 1:\n",
    "            raise ValueError('n must be a positive integer, got ' + str(n))\n",
    "        self.n = n\n",
    "\n",
    "        # calculate vertices of the regular n-simplex (one vertex per row for now):\n",
    "        #   p_1 = n^(-1/2) * (1, ..., 1),   p_{i+1} = kappa * (1, ..., 1) + mu * e_i\n",
    "        self.kappa = - (1 + (n + 1)**(1/2)) / (n**(3/2))\n",
    "        self.mu = (1 + 1 / n)**(1/2)\n",
    "        vertices = torch.zeros(n + 1, n)\n",
    "        vertices[0, :] = 1 / n**(1/2)\n",
    "        vertices[1:, :] = self.kappa + self.mu * torch.eye(n)\n",
    "\n",
    "        # calculate the simplex change-of-basis matrix M_n:\n",
    "        # every vertex is extended with n^(-1/2) and scaled to unit norm\n",
    "        extended = torch.cat((vertices, torch.full((n + 1, 1), n**(-1/2))), dim=1)\n",
    "        extended_norms = extended.norm(dim=1, keepdim=True)  # (n+1) x 1\n",
    "        basis = (1 / extended_norms) * extended\n",
    "\n",
    "        # the rotations from the first vertex to every vertex (identity for the first one)\n",
    "        rotations = [torch.eye(n).unsqueeze(0)]\n",
    "        rotations += [compute_rotation_from_two_points(vertices[0:1], v.unsqueeze(0)) for v in vertices[1:]]\n",
    "        self.R_T = torch.stack(rotations)\n",
    "\n",
    "        # take the norm of the last extended vertex explicitly\n",
    "        # instead of relying on a variable leaked from a loop\n",
    "        self.p_norm = extended_norms[-1, 0]\n",
    "        self.p = vertices.T       # n x n+1\n",
    "        self.M_n = basis.T        # n+1 x n+1\n",
    "\n",
    "    def embed(self, x, c, r):\n",
    "        '''\n",
    "        Embed the data vector and the sphere into R^(n+2) and store them:\n",
    "            self.X = (x, -1, -0.5 * ||x||^2)\n",
    "            self.S = (c, 0.5 * (||c||^2 - r^2), 1)\n",
    "\n",
    "        Args:\n",
    "            x, c - 1D torch.Tensor - the data vector and the sphere center\n",
    "            r    - python number, 0-dim or one-element torch.Tensor - the radius\n",
    "        '''\n",
    "        # the extra coordinates are assembled with torch ops rather than with\n",
    "        # torch.tensor([...]) on a list of tensors, so that the dtype, the device\n",
    "        # and the autograd graph of the inputs are kept\n",
    "        x_sq = x.pow(2).sum()\n",
    "        self.X = torch.cat((x, torch.stack((-torch.ones_like(x_sq), -0.5 * x_sq))))\n",
    "\n",
    "        radius = torch.as_tensor(r, dtype=c.dtype, device=c.device).reshape(())\n",
    "        c_sq = c.pow(2).sum()\n",
    "        self.S = torch.cat((c, torch.stack((0.5 * (c_sq - radius.pow(2)), torch.ones_like(c_sq)))))\n",
    "\n",
    "    def B_n(self, R_O):\n",
    "        '''\n",
    "        Construct the filter bank B_n from the embedded sphere self.S.\n",
    "\n",
    "        Args:\n",
    "            R_O - (n+2) x (n+2) torch.Tensor - the initial rotation extended with ones\n",
    "\n",
    "        Return:\n",
    "            (n+1) x (n+2) tensor whose i-th row is R_O^T diag(R_T[i], 1, 1) R_O S\n",
    "\n",
    "        Raises:\n",
    "            AttributeError if embed() has not been called yet\n",
    "        '''\n",
    "        sphere = getattr(self, 'S', None)\n",
    "        if sphere is None:\n",
    "            raise AttributeError('embed(x, c, r) must be called before B_n()')\n",
    "        # R_O @ S does not depend on the vertex index, so it is computed once\n",
    "        aligned_sphere = R_O @ sphere\n",
    "        rows = [\n",
    "            R_O.T @ (extend_diagonal(self.R_T[i].squeeze(0), self.n+2) @ aligned_sphere)\n",
    "            for i in range(self.n + 1)\n",
    "        ]\n",
    "        return torch.stack(rows)\n",
    "\n",
    "    def V_n(self, R, R_O):\n",
    "        '''\n",
    "        Calculate the representation V_n = M_n^T R_O R R_O^T M_n.\n",
    "\n",
    "        Args:\n",
    "            R, R_O - (n+1) x (n+1) torch.Tensor - the transformation and the\n",
    "                     initial rotation, both extended with a one\n",
    "        '''\n",
    "        return self.M_n.T @ R_O @ R @ R_O.T @ self.M_n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4417a818",
   "metadata": {},
   "source": [
    "### Define the variables: choose $n$ and run the rest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03fd860d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose the dimensionality n of the input space;\n",
    "# the cells below build a regular n-simplex with n + 1 vertices and embed into R^(n+2)\n",
    "n = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e7701a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "equi_sphere = EquivariantHyperspheres(n)\n",
    "\n",
    "\n",
    "P_n = equi_sphere.p # the regular simplex (one vertex per column)\n",
    "p_norm = equi_sphere.p_norm # the scalar p in the paper\n",
    "M_n = equi_sphere.M_n # the simplex change-of-basis matrix\n",
    "R_T = equi_sphere.R_T # the rotations from the first simplex vertex p_1 to p_i\n",
    "\n",
    "\n",
    "# fix the seed so that Restart & Run All reproduces the same numbers;\n",
    "# change it to check the identities on other random inputs\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# get a random vector and sphere in R^n:\n",
    "x = torch.randn(n) # random input data vector\n",
    "c = torch.randn(n) # random sphere center\n",
    "r = torch.randn(1).abs() # radius (non-negative; only r^2 enters the embedding)\n",
    "equi_sphere.embed(x, c, r)\n",
    "\n",
    "X = equi_sphere.X # the embedded input data vector\n",
    "S = equi_sphere.S # the embedded sphere\n",
    "\n",
    "# the initial rotation from c to ||c||p1:\n",
    "R_O = compute_rotation_from_two_points((c / c.norm()).unsqueeze(0), P_n[:, 0].unsqueeze(0))[0] \n",
    "R_O_n1n1 = extend_diagonal(R_O, n+1) # the same appended with a one to (n+1)x(n+1)\n",
    "R_O_n2n2 = extend_diagonal(R_O, n+2) # the same appended with ones to (n+2)x(n+2)\n",
    "\n",
    "# a random nxn rotation/reflection:\n",
    "R = random_orthogonal_matrix(n)      \n",
    "R_n1n1 = extend_diagonal(R, n+1)  # the same appended with a one to (n+1)x(n+1)\n",
    "R_n2n2 = extend_diagonal(R, n+2)  # the same appended with ones to (n+2)x(n+2)\n",
    "\n",
    "\n",
    "V_n = equi_sphere.V_n(R_n1n1, R_O_n1n1) # the transformation representation in the output space\n",
    "B_n = equi_sphere.B_n(R_O_n2n2) # the filter bank containing the spheres forming a regular simplex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9f18085b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 0.8407,  0.8445,  1.8319, -1.5569]),\n",
       " tensor([ 0.8407,  0.8445,  1.8319, -1.5569,  2.9811,  1.0000]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the sphere center c next to its embedding S = (c, 0.5 * (||c||^2 - r^2), 1)\n",
    "c, S"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2546d9c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.8407,  0.8445,  1.8319, -1.5569,  2.9811,  1.0000],\n",
       "        [ 1.8062, -1.1950, -1.5476, -0.3377,  2.9811,  1.0000],\n",
       "        [-1.1962,  1.8023, -1.5512, -0.3375,  2.9811,  1.0000],\n",
       "        [-1.8539, -1.8563,  0.4923, -0.2726,  2.9811,  1.0000],\n",
       "        [ 0.4032,  0.4046,  0.7746,  2.5047,  2.9811,  1.0000]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the filter bank: n+1 embedded spheres, one per row; the first row is S itself\n",
    "B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8f44540c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.5000,  0.7135, -0.4045, -0.4045, -0.4045],\n",
       "        [ 0.5000, -0.4045,  0.7135, -0.4045, -0.4045],\n",
       "        [ 0.5000, -0.4045, -0.4045,  0.7135, -0.4045],\n",
       "        [ 0.5000, -0.4045, -0.4045, -0.4045,  0.7135]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the regular simplex: each of the n+1 columns is a vertex in R^n\n",
    "P_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8bb86608",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.1180)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the norm of a simplex vertex extended with n^(-1/2); for n = 4 it equals sqrt(5)/2\n",
    "p_norm  # the scalar p in the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8a03e039",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.4472,  0.6382, -0.3618, -0.3618, -0.3618],\n",
       "        [ 0.4472, -0.3618,  0.6382, -0.3618, -0.3618],\n",
       "        [ 0.4472, -0.3618, -0.3618,  0.6382, -0.3618],\n",
       "        [ 0.4472, -0.3618, -0.3618, -0.3618,  0.6382],\n",
       "        [ 0.4472,  0.4472,  0.4472,  0.4472,  0.4472]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the (n+1)x(n+1) change-of-basis matrix: its columns are the simplex vertices\n",
    "# extended with n^(-1/2) and scaled to unit norm\n",
    "M_n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b02acb",
   "metadata": {},
   "source": [
    "### Numerical verification of equivariance (Theorem 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "578605e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# left-hand side of the equivariance identity:\n",
    "# the representation V_n applied to the filter bank outputs B_n @ X\n",
    "V_n @ B_n @ X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f375aa1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# right-hand side of the equivariance identity: the filter bank applied to the\n",
    "# transformed input; assert the match with the left-hand side V_n @ B_n @ X\n",
    "# instead of only comparing the two printed tensors by eye\n",
    "assert torch.allclose(V_n @ B_n @ X, B_n @ R_n2n2 @ X, rtol=1e-3, atol=1e-3)\n",
    "B_n @ R_n2n2 @ X"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28223556",
   "metadata": {},
   "source": [
    "### In-between computations and misc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9918d3d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[     1.1180,     -0.0000,     -0.0000,     -0.0000],\n",
       "        [    -0.0000,      1.1180,     -0.0000,     -0.0000],\n",
       "        [    -0.0000,     -0.0000,      1.1180,     -0.0000],\n",
       "        [    -0.0000,     -0.0000,     -0.0000,      1.1180],\n",
       "        [     0.0000,      0.0000,      0.0000,      0.0000]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: the result is p_norm * I_n stacked on a row of zeros (up to rounding)\n",
    "M_n @ P_n.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c4f6dc47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[     2.4863,     -0.5119,     -0.0656,     -1.5973,      0.0000,\n",
       "              0.0000],\n",
       "        [    -0.5161,      2.4855,     -0.0692,     -1.5970,      0.0000,\n",
       "              0.0000],\n",
       "        [    -1.1737,     -1.1732,      1.9743,     -1.5321,      0.0000,\n",
       "              0.0000],\n",
       "        [     1.0833,      1.0878,      2.2566,      1.2451,     -0.0000,\n",
       "              0.0000],\n",
       "        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,\n",
       "              2.2361]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step 1: M_n applied to the filter bank; the last two columns are non-zero\n",
    "# only in the last row\n",
    "M_n @ B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "315493d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[     2.9999,     -0.0000,     -0.0000,      0.0000,      0.0000,\n",
       "              0.0000],\n",
       "        [    -0.0000,      2.9999,     -0.0000,      0.0000,     -0.0000,\n",
       "             -0.0000],\n",
       "        [    -0.0000,     -0.0000,      2.9999,      0.0000,     -0.0000,\n",
       "             -0.0000],\n",
       "        [     0.0000,      0.0000,      0.0000,      2.9999,     -0.0000,\n",
       "             -0.0000],\n",
       "        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,\n",
       "              2.2361]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step 2: undoing the initial rotation R_O leaves a scaled identity in the top-left\n",
    "# nxn block (numerically the factor matches p_norm * ||c||)\n",
    "R_O_n1n1.T @ M_n @ B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1e7e9058",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    -1.5642,      1.5692,      1.0586,     -1.7234,     -0.0000,\n",
       "              0.0000],\n",
       "        [    -0.6657,     -2.3917,      1.5702,     -0.6090,     -0.0000,\n",
       "              0.0000],\n",
       "        [    -1.8822,     -0.8062,     -2.1636,     -0.3547,      0.0000,\n",
       "              0.0000],\n",
       "        [    -1.6022,      0.4089,      0.8558,      2.3522,     -0.0000,\n",
       "             -0.0000],\n",
       "        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,\n",
       "              2.2361]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step 3: apply the random transformation R (extended to (n+1)x(n+1))\n",
    "R_n1n1 @ R_O_n1n1.T @ M_n @ B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "942ffeda",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    -0.2886,      1.5085,      0.2010,     -2.5691,     -0.0000,\n",
       "              0.0000],\n",
       "        [     0.6139,     -2.4506,      0.7131,     -1.4521,      0.0000,\n",
       "              0.0000],\n",
       "        [     0.4519,     -0.4181,     -2.8892,     -0.5223,      0.0000,\n",
       "              0.0000],\n",
       "        [    -2.8870,     -0.7373,     -0.3207,     -0.1337,     -0.0000,\n",
       "             -0.0000],\n",
       "        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,\n",
       "              2.2361]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step 4: re-apply the initial rotation R_O\n",
    "R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8c581eeb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9436, -0.9380, -1.0267, -2.0917,  2.9811,  1.0000],\n",
       "        [ 0.4748,  2.2674,  1.0316, -0.8768,  2.9811,  1.0000],\n",
       "        [ 1.3772, -1.6917,  1.5437,  0.2401,  2.9811,  1.0000],\n",
       "        [ 1.2152,  0.3408, -2.0586,  1.1699,  2.9811,  1.0000],\n",
       "        [-2.1237,  0.0216,  0.5099,  1.5585,  2.9811,  1.0000]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# step 5: map back with M_n.T; the full product is V_n @ B_n, because\n",
    "# V_n = M_n.T @ R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n\n",
    "M_n.T @ R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "080ceca6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0000, 0.0000, 0.0000, 0.0000, 2.2361])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# M_n maps the all-ones vector onto the last axis; for n = 4 the result is (0, 0, 0, 0, sqrt(5))\n",
    "M_n @ torch.ones(n+1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90a38eff",
   "metadata": {},
   "source": [
    "### Numeric instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b020a673",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Closed-form instances of the simplex P_n, the scalar p and the change-of-basis\n",
    "# matrix M_n for n = 2, 3, 4. For every n, M_n consists of the rows of P_n\n",
    "# (up to a scalar factor) followed by a row of ones, so the shared rows are\n",
    "# written down once and reused.\n",
    "\n",
    "def check_simplex_instance(P, M):\n",
    "    '''Assert that the columns of P have unit norm and that M is orthogonal.'''\n",
    "    size = M.shape[0]\n",
    "    assert torch.allclose(P.norm(dim=0), torch.ones(size), atol=1e-5)\n",
    "    assert torch.allclose(M @ M.T, torch.eye(size), atol=1e-5)\n",
    "\n",
    "\n",
    "sqrt3 = 3 ** 0.5\n",
    "sqrt5 = 5 ** 0.5\n",
    "\n",
    "# n=2:\n",
    "rows_2 = torch.tensor([[1, (sqrt3 - 1) / 2, -(sqrt3 + 1) / 2],\n",
    "                       [1, -(sqrt3 + 1) / 2, (sqrt3 - 1) / 2]])\n",
    "\n",
    "P_2 = 2**(-0.5) * rows_2\n",
    "\n",
    "p_norm_2 = 1.5**0.5\n",
    "\n",
    "M_2 = 3**(-0.5) * torch.cat((rows_2, torch.ones(1, 3)))\n",
    "\n",
    "\n",
    "# n=3:\n",
    "rows_3 = torch.tensor([[1.0,  1.0, -1.0, -1.0],\n",
    "                       [1.0, -1.0,  1.0, -1.0],\n",
    "                       [1.0, -1.0, -1.0,  1.0]])\n",
    "\n",
    "P_3 = (1 / sqrt3) * rows_3\n",
    "\n",
    "p_norm_3 = 2 * 3**(-0.5)\n",
    "\n",
    "M_3 = 0.5 * torch.cat((rows_3, torch.ones(1, 4)))\n",
    "\n",
    "\n",
    "# n=4:\n",
    "big = (3 * sqrt5 - 1) / 4     # the entry in the position of the vertex' own axis\n",
    "small = -(sqrt5 + 1) / 4      # all other entries\n",
    "rows_4 = torch.tensor([[1, big, small, small, small],\n",
    "                       [1, small, big, small, small],\n",
    "                       [1, small, small, big, small],\n",
    "                       [1, small, small, small, big]])\n",
    "\n",
    "P_4 = 1/2 * rows_4\n",
    "\n",
    "p_norm_4 = sqrt5 / 2\n",
    "\n",
    "M_4 = 1/sqrt5 * torch.cat((rows_4, torch.ones(1, 5)))\n",
    "\n",
    "\n",
    "# sanity checks of the hand-written instances:\n",
    "check_simplex_instance(P_2, M_2)\n",
    "check_simplex_instance(P_3, M_3)\n",
    "check_simplex_instance(P_4, M_4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
