Learning Deep O($n$)-Equivariant Hyperspheres¶
This notebook is meant to be used to verify the theoretical results from the paper, such as the equivariance of the proposed $n \text{D}$ neurons.
The complete implementation of our method, as well as experiment-related scripts, will be made publicly available upon acceptance.
In [1]:
import torch
torch.set_printoptions(precision=4, sci_mode=False)
Helper functions¶
In [2]:
def extend_diagonal(R, n):
    """
    Embed a square matrix into the top-left corner of an n x n identity.

    Args:
        R - torch.Tensor of shape (m, m) with m <= n
        n - int - size of the returned matrix
    Return:
        torch.Tensor of shape (n, n), with the dtype and device of R, equal to
        the block-diagonal matrix diag(R, I_{n-m}).
    Raises:
        AssertionError if R is not a square matrix or if m > n.
    """
    assert R.dim() == 2 and R.shape[0] == R.shape[1], 'R must be a square matrix'
    m = R.shape[0]
    assert m <= n
    # build the identity with the dtype/device of R so that float64 or GPU
    # inputs are neither downcast nor moved to the default device
    extended = torch.eye(n, dtype=R.dtype, device=R.device)
    extended[:m, :m] = R
    return extended
def compute_rotation_from_two_points(p, q):
    '''
    Batched rotation taking p to q, composed of two reflections.

    With f(A, u) = A - 2 u (u^T A) / ||u||^2:
        S = f(I_D, p + q)    # S p = -q, S q = -p
        R = f(S, q)          # R p =  q, R^T q = p

    Args:
        p, q - torch.Tensor of shape (B, D) - batches of D-dimensional points,
               necessarily with ||p_b|| == ||q_b||, p_b != -q_b and q_b != 0
               for every batch index b
    Return:
        R - torch.Tensor of shape (B, D, D), with the dtype and device of p,
            such that R[b] @ p[b] = q[b]
    Raises:
        AssertionError on mismatching shapes, unequal norms, or degenerate
        inputs for which the construction would divide by zero.
    '''
    assert len(p.shape) == 2 and p.shape == q.shape
    a = torch.abs(p.norm(dim=1, keepdim=True).pow(2) - q.norm(dim=1, keepdim=True).pow(2)).max()
    assert a < 1e-5, 'Such a rotation doesn\'t exist: ||p|| must be equal to ||q||, '+str(a)
    # both reflections divide by the squared norm of their direction vector;
    # fail loudly instead of silently returning NaN matrices
    assert (p + q).norm(dim=1).min() > 0 and q.norm(dim=1).min() > 0, \
        'The rotation is undefined for q == 0 or p == -q'

    B, D = p.shape

    def reflection(S, u):
        # apply f(S, u) to every batch element:
        # S is (BxDxD), u is (BxD); they must agree in batch size and in D.
        assert len(S) == len(u) and S.shape[-1] == u.shape[-1]
        v = torch.matmul(u.unsqueeze(1), S)                 # u^T S, (Bx1xD)
        v = v.squeeze(1) / u.norm(dim=1, keepdim=True)**2   # (BxD) / (Bx1) = (BxD)
        # the matmul performs the batched outer product of u and v
        return S - 2 * torch.matmul(u.unsqueeze(-1), v.unsqueeze(1))

    # the identity inherits dtype and device from p, so non-default dtypes
    # (e.g. float64) do not break the matmul inside `reflection`
    identity = torch.eye(D, dtype=p.dtype, device=p.device).repeat(B, 1, 1)
    S = reflection(identity, p + q)   # S @ p = -q, S @ q = -p
    R = reflection(S, q)              # R @ p = q, R.T @ q = p
    return R
def random_orthogonal_matrix(n):
    """
    Sample a random n x n orthogonal matrix (a rotation or a reflection).

    A matrix with i.i.d. standard normal entries is orthogonalized with the QR
    decomposition. Uniform entries in [0, 1) would be a poor choice here: all
    entries of the first column of Q would share one sign, so large parts of
    O(n) could never be sampled.

    Args:
        n - int - size of the matrix
    Return:
        Q - torch.Tensor of shape (n, n) with Q @ Q.T = I
    """
    # generate a random nxn Gaussian matrix:
    random_matrix = torch.randn(n, n)
    # use the QR decomposition to orthogonalize the matrix:
    Q, R = torch.linalg.qr(random_matrix)
    # QR is unique only up to the signs of the columns of Q; fixing them so that
    # diag(R) > 0 removes the dependence on the LAPACK sign convention
    diag = torch.diagonal(R)
    signs = torch.where(diag < 0, -torch.ones_like(diag), torch.ones_like(diag))
    return Q * signs
Main class¶
In [3]:
class EquivariantHyperspheres:
    """
    Regular-simplex constructions used to check the equivariance of the nD neurons.

    Attributes set by __init__:
        n         - dimensionality of the data space
        kappa, mu - scalars of the closed-form simplex vertices
        p         - (n, n+1) tensor; its columns are the vertices of the regular n-simplex
        p_norm    - 0-dim tensor; norm of a vertex extended with the entry n^(-1/2)
        M_n       - (n+1, n+1) tensor; its columns are the normalized extended vertices
        R_T       - (n+1, 1, n, n) tensor; R_T[i, 0] rotates vertex 0 onto vertex i
                    (R_T[0, 0] is the identity)
    Attributes set by embed:
        X, S      - the data vector and the sphere embedded in R^(n+2)
    """

    def __init__(self, n):
        # For n = 1 the two simplex vertices are +1 and -1; the rotation between
        # them is built from a reflection along p + q = 0 and would come out as NaN.
        if n < 2:
            raise ValueError('n must be an integer >= 2, got ' + str(n))
        self.n = n

        # calculate vertices of the regular n-simplex (one vertex per row for now):
        self.kappa = - (1 + (n + 1)**(1/2)) / (n**(3/2))
        self.mu = (1 + 1 / n)**(1/2)
        vertices = torch.zeros(n + 1, n)
        vertices[0, :] = 1 / n**(1/2)
        vertices[1:, :] = self.kappa + self.mu * torch.eye(n)

        # calculate the simplex change-of-basis matrix M_n:
        # every vertex is extended with n^(-1/2) and scaled to unit length
        extension = torch.tensor([n**(-1/2)])
        rows, norms = [], []
        for vertex in vertices:
            extended = torch.cat((vertex, extension))
            norms.append(torch.norm(extended))
            rows.append((1 / norms[-1]) * extended)

        # rotations from the first vertex to every vertex (identity for the first one):
        rotations = [torch.eye(n).unsqueeze(0)]
        rotations += [compute_rotation_from_two_points(vertices[0:1], v.unsqueeze(0))
                      for v in vertices[1:]]
        self.R_T = torch.stack(rotations)

        # the norm of the last extended vertex, as before (an explicit value
        # instead of a variable leaking out of the loop)
        self.p_norm = norms[-1]
        self.p = vertices.T                 # n x n+1
        self.M_n = torch.stack(rows).T      # n+1 x n+1

    def embed(self, x, c, r):
        """
        Embed a data vector and a sphere into R^(n+2) and store them as self.X, self.S.

            X = (x, -1, -0.5 ||x||^2)
            S = (c, 0.5 (||c||^2 - r^2), 1)

        Args:
            x - torch.Tensor of shape (n,) - data vector
            c - torch.Tensor of shape (n,) - sphere center
            r - number or one-element torch.Tensor - sphere radius (only r^2 is used)
        """
        # The extra coordinates are assembled with tensor operations (not with
        # torch.tensor([...]) around tensors), so they keep the dtype/device of
        # the inputs and remain part of the autograd graph.
        x_tail = (-0.5 * torch.norm(x) ** 2).reshape(1)
        self.X = torch.cat((x, x.new_tensor([-1.0]), x_tail))
        c_tail = (0.5 * (torch.norm(c) ** 2 - r ** 2)).reshape(1)
        self.S = torch.cat((c, c_tail, c.new_ones(1)))

    def B_n(self, R_O):
        """
        Construct the filter bank B_n; `embed` must have been called before.

        Args:
            R_O - (n+2, n+2) torch.Tensor - the initial rotation extended with ones
        Return:
            (n+1, n+2) torch.Tensor; row i is R_O^T diag(R_T[i], I_2) R_O S.
        """
        spheres = [
            R_O.T @ extend_diagonal(R_i.squeeze(0), self.n+2) @ R_O @ self.S
            for R_i in self.R_T
        ]
        return torch.stack(spheres)

    def V_n(self, R, R_O):
        """
        Calculate the representation V_n = M_n^T R_O R R_O^T M_n.

        Args:
            R   - (n+1, n+1) torch.Tensor - the transformation extended with a one
            R_O - (n+1, n+1) torch.Tensor - the initial rotation extended with a one
        Return:
            (n+1, n+1) torch.Tensor
        """
        return self.M_n.T @ R_O @ R @ R_O.T @ self.M_n
Define the variables: choose $n$ and run the rest¶
In [4]:
# choose the dimensionality n of the space (the outputs displayed below were produced with n = 4)
n = 4
In [5]:
# Build every quantity used by the checks below. No random seed is set, so the
# numbers change from run to run.
equi_sphere = EquivariantHyperspheres(n)
P_n = equi_sphere.p # the regular simplex, n x (n+1): one vertex per column
p_norm = equi_sphere.p_norm # the scalar p in the paper
M_n = equi_sphere.M_n # the simplex change-of-basis matrix, (n+1) x (n+1)
R_T = equi_sphere.R_T # the rotations from the first simplex vertex p_1 to p_i
# get a random vector and sphere in R^n:
x = torch.randn(n) # random input data vector
c = torch.randn(n) # random sphere center
r = torch.randn(1) # radius (only r**2 enters the embedding, so its sign does not matter)
equi_sphere.embed(x, c, r)
X = equi_sphere.X # the embedded input data vector
S = equi_sphere.S # the embedded sphere
# the initial rotation from c to ||c||p1:
# (computed on the normalized center and the first simplex vertex; [0] drops the batch dimension)
R_O = compute_rotation_from_two_points(c.unsqueeze(0)/c.norm(), equi_sphere.p[:,0].unsqueeze(0))[0]
R_O_n1n1 = extend_diagonal(R_O, n+1) # the same appended with a one to (n+1)x(n+1)
R_O_n2n2 = extend_diagonal(R_O, n+2) # the same appended with ones to (n+2)x(n+2)
# a random nxn rotation/reflection:
R = random_orthogonal_matrix(n)
R_n1n1 = extend_diagonal(R, n+1) # the same appended with a one to (n+1)x(n+1)
R_n2n2 = extend_diagonal(R, n+2) # the same appended with ones to (n+2)x(n+2)
V_n = equi_sphere.V_n(R_n1n1, R_O_n1n1) # the transformation representation in the output space
B_n = equi_sphere.B_n(R_O_n2n2) # the filter bank containing the spheres forming a regular simplex
In [6]:
# S is c followed by the two embedding coordinates 0.5*(||c||^2 - r^2) and 1
c, S
Out[6]:
(tensor([ 0.8407, 0.8445, 1.8319, -1.5569]), tensor([ 0.8407, 0.8445, 1.8319, -1.5569, 2.9811, 1.0000]))
In [7]:
# one embedded sphere per row; row 0 equals S because R_T[0] is the identity
B_n
Out[7]:
tensor([[ 0.8407, 0.8445, 1.8319, -1.5569, 2.9811, 1.0000], [ 1.8062, -1.1950, -1.5476, -0.3377, 2.9811, 1.0000], [-1.1962, 1.8023, -1.5512, -0.3375, 2.9811, 1.0000], [-1.8539, -1.8563, 0.4923, -0.2726, 2.9811, 1.0000], [ 0.4032, 0.4046, 0.7746, 2.5047, 2.9811, 1.0000]])
In [8]:
# n x (n+1): each column is one vertex of the regular simplex
P_n
Out[8]:
tensor([[ 0.5000, 0.7135, -0.4045, -0.4045, -0.4045], [ 0.5000, -0.4045, 0.7135, -0.4045, -0.4045], [ 0.5000, -0.4045, -0.4045, 0.7135, -0.4045], [ 0.5000, -0.4045, -0.4045, -0.4045, 0.7135]])
In [9]:
# norm of a simplex vertex extended with the entry n^(-1/2)
p_norm # the scalar p in the paper
Out[9]:
tensor(1.1180)
In [10]:
# (n+1) x (n+1): each column is a simplex vertex extended with n^(-1/2) and scaled to unit norm
M_n
Out[10]:
tensor([[ 0.4472, 0.6382, -0.3618, -0.3618, -0.3618], [ 0.4472, -0.3618, 0.6382, -0.3618, -0.3618], [ 0.4472, -0.3618, -0.3618, 0.6382, -0.3618], [ 0.4472, -0.3618, -0.3618, -0.3618, 0.6382], [ 0.4472, 0.4472, 0.4472, 0.4472, 0.4472]])
Numerical verification of equivariance (Theorem 4)¶
In [11]:
# left-hand side: the representation V_n applied to the filter-bank output for the original input X
V_n @ B_n @ X
Out[11]:
tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])
In [12]:
# right-hand side: the filter bank applied to the transformed input; it should match the previous cell
B_n @ R_n2n2 @ X
Out[12]:
tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])
In-between computations and misc¶
In [13]:
# in the output shown: p_norm on the diagonal of the top n x n block, and a zero last row
M_n @ P_n.T
Out[13]:
tensor([[ 1.1180, -0.0000, -0.0000, -0.0000], [ -0.0000, 1.1180, -0.0000, -0.0000], [ -0.0000, -0.0000, 1.1180, -0.0000], [ -0.0000, -0.0000, -0.0000, 1.1180], [ 0.0000, 0.0000, 0.0000, 0.0000]])
In [14]:
# M_n applied to the filter bank; in the output shown, the first n columns and the last two columns do not mix
M_n @ B_n
Out[14]:
tensor([[ 2.4863, -0.5119, -0.0656, -1.5973, 0.0000, 0.0000], [ -0.5161, 2.4855, -0.0692, -1.5970, 0.0000, 0.0000], [ -1.1737, -1.1732, 1.9743, -1.5321, 0.0000, 0.0000], [ 1.0833, 1.0878, 2.2566, 1.2451, -0.0000, 0.0000], [ 0.0000, -0.0000, 0.0000, -0.0000, 6.6659, 2.2361]])
In [15]:
# additionally apply R_O^T; in the output shown, the top-left n x n block is a multiple of the identity
R_O_n1n1.T @ M_n @ B_n
Out[15]:
tensor([[ 2.9999, -0.0000, -0.0000, 0.0000, 0.0000, 0.0000], [ -0.0000, 2.9999, -0.0000, 0.0000, -0.0000, -0.0000], [ -0.0000, -0.0000, 2.9999, 0.0000, -0.0000, -0.0000], [ 0.0000, 0.0000, 0.0000, 2.9999, -0.0000, -0.0000], [ 0.0000, -0.0000, 0.0000, -0.0000, 6.6659, 2.2361]])
In [16]:
# next factor of V_n: the random transformation R, extended to (n+1) x (n+1)
R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[16]:
tensor([[ -1.5642, 1.5692, 1.0586, -1.7234, -0.0000, 0.0000], [ -0.6657, -2.3917, 1.5702, -0.6090, -0.0000, 0.0000], [ -1.8822, -0.8062, -2.1636, -0.3547, 0.0000, 0.0000], [ -1.6022, 0.4089, 0.8558, 2.3522, -0.0000, -0.0000], [ 0.0000, -0.0000, 0.0000, -0.0000, 6.6659, 2.2361]])
In [17]:
# next factor of V_n: the initial rotation R_O, extended to (n+1) x (n+1)
R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[17]:
tensor([[ -0.2886, 1.5085, 0.2010, -2.5691, -0.0000, 0.0000], [ 0.6139, -2.4506, 0.7131, -1.4521, 0.0000, 0.0000], [ 0.4519, -0.4181, -2.8892, -0.5223, 0.0000, 0.0000], [ -2.8870, -0.7373, -0.3207, -0.1337, -0.0000, -0.0000], [ 0.0000, -0.0000, 0.0000, -0.0000, 6.6659, 2.2361]])
In [18]:
# last factor of V_n: M_n^T; the whole product is V_n @ B_n written out factor by factor
M_n.T @ R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[18]:
tensor([[-0.9436, -0.9380, -1.0267, -2.0917, 2.9811, 1.0000], [ 0.4748, 2.2674, 1.0316, -0.8768, 2.9811, 1.0000], [ 1.3772, -1.6917, 1.5437, 0.2401, 2.9811, 1.0000], [ 1.2152, 0.3408, -2.0586, 1.1699, 2.9811, 1.0000], [-2.1237, 0.0216, 0.5099, 1.5585, 2.9811, 1.0000]])
In [19]:
# M_n applied to the all-ones vector; in the output shown, only the last entry is non-zero
M_n @ torch.ones(n+1)
Out[19]:
tensor([0.0000, 0.0000, 0.0000, 0.0000, 2.2361])
Numeric instances¶
In [20]:
# Closed-form instances for n = 2, 3, 4:
#   P_k      - the regular simplex, k x (k+1), one vertex per column
#   p_norm_k - the scalar p
#   M_k      - the simplex change-of-basis matrix, (k+1) x (k+1)

def _simplex_rows(first, diag, off, k):
    # k rows of length k + 1: row i starts with `first`, has `diag` in
    # column i + 1 and `off` in every other column
    return [[first] + [diag if j == i else off for j in range(k)] for i in range(k)]

# n=2:
_diag_2 = (3 ** 0.5 - 1) / 2
_off_2 = -(3 ** 0.5 + 1) / 2
P_2 = 2**(-0.5) * torch.tensor(_simplex_rows(1, _diag_2, _off_2, 2))
p_norm_2 = 1.5**0.5
M_2 = 3**(-0.5) * torch.tensor(_simplex_rows(1, _diag_2, _off_2, 2) + [[1] * 3])

# n=3:
_s_3 = 1 / (3 ** 0.5)
P_3 = torch.tensor(_simplex_rows(_s_3, _s_3, -_s_3, 3))
p_norm_3 = 2 * 3**(-0.5)
M_3 = 0.5 * torch.tensor(_simplex_rows(1, 1, -1, 3) + [[1] * 4])

# n=4:
_diag_4 = (3 * 5**0.5 - 1) / 4
_off_4 = -(5**0.5 + 1) / 4
P_4 = 1/2 * torch.tensor(_simplex_rows(1, _diag_4, _off_4, 4))
p_norm_4 = 5**0.5 / 2
M_4 = 1/5**0.5 * torch.tensor(_simplex_rows(1, _diag_4, _off_4, 4) + [[1] * 5])