Learning Deep O($n$)-Equivariant Hyperspheres¶

This notebook numerically verifies the theoretical results from the paper, such as the equivariance of the proposed $n\text{D}$ neurons.

The complete implementation of our method, as well as experiment-related scripts, will be made publicly available upon acceptance.

In [1]:
import torch
torch.set_printoptions(precision=4, sci_mode=False)

Helper functions¶

In [2]:
def extend_diagonal(R, n):
    """Embed a square matrix in the top-left corner of an n x n identity.

    Args:
        R - torch.Tensor - an m x m matrix with m <= n
        n - int - the size of the returned matrix

    Return:
        a new n x n tensor equal to the block-diagonal matrix diag(R, I_{n-m}),
        with the same dtype and device as R (R itself is not modified)
    """
    assert R.dim() == 2 and R.shape[0] == R.shape[1], 'R must be a square matrix'
    m = R.shape[0]
    assert m <= n
    # create the identity with R's dtype/device, so that e.g. a float64 R
    # is not silently truncated to the default float32 when it is copied in
    extended = torch.eye(n, dtype=R.dtype, device=R.device)
    extended[:m, :m] = R
    return extended


def compute_rotation_from_two_points(p, q):
    '''
    A reflection method:
        assuming ||p|| == ||q||
        f(A, u) = A - 2 u (u^T A)/||u||^2
        S = f(I_n, p+q)
        R = f(S, q)

    Args:
        p, q - torch.Tensor - two batches of nD points of shape (B, D),
               necessarily with ||p[b]|| == ||q[b]|| for every b

    Return:
        R - (B, D, D) tensor of rotation matrices such that R[b] @ p[b] = q[b],
            with the same dtype and device as p

    Note:
        the construction divides by ||p+q||^2 and ||q||^2, so antipodal
        points (p == -q) or zero vectors produce non-finite entries.
    '''
    assert len(p.shape) == 2 and p.shape == q.shape
    a = torch.abs(p.norm(dim=1, keepdim=True).pow(2) - q.norm(dim=1, keepdim=True).pow(2)).max()

    assert a < 1e-5, 'Such a rotation doesn\'t exist: ||p|| must be equal to ||q||, '+str(a)
    B, D = p.shape

    def reflection(S, u):
        # reflection of S on hyperplane u:
        # S can be a matrix; S and u must have the same number of rows.
        assert len(S) == len(u) and S.shape[-1] == u.shape[-1]

        v = torch.matmul(u.unsqueeze(1), S) # (Bx1xD)
        v = v.squeeze(1) / u.norm(dim=1, keepdim=True)**2 # (BxD) / (Bx1) = (BxD)
        M = S - 2 * torch.matmul(u.unsqueeze(-1), v.unsqueeze(1)) # the matmul performs the outer product of u and v
        return M

    # the identity must share p's dtype: torch.matmul does not promote mixed
    # dtypes, so a default float32 identity breaks e.g. float64 inputs
    identity = torch.eye(D, dtype=p.dtype, device=p.device).repeat(B, 1, 1)

    S = reflection(identity, p+q)  # S @ p = -q, S @ q = -p
    R = reflection(S, q) # R @ p = q, R.T @ q = p

    return R



def random_orthogonal_matrix(n):
    """Return a random n x n orthogonal matrix.

    The matrix is the Q factor of the QR decomposition of an n x n matrix
    whose entries are drawn with torch.rand.
    """
    # only the orthogonal factor of the decomposition is needed
    return torch.linalg.qr(torch.rand(n, n))[0]

Main class¶

In [3]:
class EquivariantHyperspheres:
    """Building blocks for checking the equivariance of the nD hypersphere neurons.

    Attributes set by the constructor:
        n         - dimensionality of the space
        kappa, mu - scalars used to build the simplex vertices
        p         - n x (n+1) tensor whose columns are the vertices of the regular n-simplex
        M_n       - (n+1) x (n+1) simplex change-of-basis matrix
        R_T       - (n+1) x 1 x n x n tensor; R_T[0] is the identity and R_T[i]
                    is the rotation taking the first vertex to the i-th one
        p_norm    - 0-d tensor, the norm of a vertex extended with the coordinate n^(-1/2)

    Attributes set by embed():
        X, S      - the embedded data vector and the embedded sphere (length n+2 each)
    """

    def __init__(self, n):
        self.n = n

        # calculate vertices of the regular n-simplex (one vertex per row for now):
        self.kappa = - (1 + (n + 1)**(1/2)) / (n**(3/2))
        self.mu = (1 + 1 / n)**(1/2)
        self.p = torch.zeros(n + 1, n)
        self.p[0, :] = 1 / n**(1/2)
        # rows 1..n in one shot: kappa everywhere, plus mu on the diagonal
        self.p[1:, :] = self.kappa + self.mu * torch.eye(n)

        # calculate the simplex change-of-basis matrix M_n: every row is a
        # vertex extended with the coordinate n^(-1/2) and scaled to unit norm
        self.M_n = torch.zeros(n + 1, n + 1)
        extended_norm = None
        for i in range(n + 1):
            p_extended = torch.cat((self.p[i, :], torch.tensor([n**(-1/2)])))
            extended_norm = torch.norm(p_extended)
            self.M_n[i, :] = (1 / extended_norm) * p_extended

        # rotations from the first vertex to every vertex (identity for the first one).
        # NOTE: for n = 1 the two vertices are antipodal, for which
        # compute_rotation_from_two_points divides by ||p+q||^2 = 0.
        first_vertex = self.p[0:1]
        rotations = [torch.eye(n).unsqueeze(0)]
        rotations += [compute_rotation_from_two_points(first_vertex, vertex.unsqueeze(0))
                      for vertex in self.p[1:]]
        self.R_T = torch.stack(rotations)

        # the norm of the last extended vertex, stored explicitly
        self.p_norm = extended_norm
        self.p = self.p.T         # n x n+1
        self.M_n = self.M_n.T     # n+1 x n+1

    def embed(self, x, c, r):
        """Embed a data vector and a sphere into R^(n+2).

        Args:
            x - torch.Tensor of shape (n,) - the data vector
            c - torch.Tensor of shape (n,) - the sphere center
            r - float or one-element torch.Tensor - the sphere radius
                (only r**2 enters the embedding)

        Sets:
            self.X = (x, -1, -0.5 * ||x||^2)
            self.S = (c, 0.5 * (||c||^2 - r^2), 1)

        The extra coordinates are concatenated as tensors (rather than being
        re-created with torch.tensor), so they stay on the inputs' device and
        remain part of the autograd graph of x, c and r.
        """
        x_tail = (-0.5 * torch.norm(x) ** 2).reshape(1)
        self.X = torch.cat((x, -torch.ones(1, dtype=x.dtype, device=x.device), x_tail))

        s_tail = (0.5 * (torch.norm(c) ** 2 - r ** 2)).reshape(1)
        self.S = torch.cat((c, s_tail, torch.ones(1, dtype=c.dtype, device=c.device)))

    def B_n(self, R_O):
        """Construct the filter bank B_n from the embedded sphere self.S.

        Args:
            R_O - (n+2) x (n+2) torch.Tensor - the initial rotation, extended with ones

        Return:
            (n+1) x (n+2) tensor whose i-th row is
            R_O^T diag(R_T[i], 1, 1) R_O S
        """
        B_n = torch.zeros(self.n+1, self.n+2)
        for i in range(self.n + 1):
            R_T_i = extend_diagonal(self.R_T[i].squeeze(0), self.n+2)
            B_n[i, :] = R_O.T @ R_T_i @ R_O @ self.S
        return B_n

    def V_n(self, R, R_O):
        """Calculate the representation V_n = M_n^T R_O R R_O^T M_n.

        Args:
            R   - (n+1) x (n+1) torch.Tensor - the transformation, extended with a one
            R_O - (n+1) x (n+1) torch.Tensor - the initial rotation, extended with a one
        """
        return self.M_n.T @ R_O @ R @ R_O.T @ self.M_n

Define the variables: choose $n$ and run the rest¶

In [4]:
# choose the dimensionality n of the space R^n
# (the regular simplex built from it has n + 1 vertices)
n = 4
In [5]:
# fix the RNG state so that re-running this cell (or Restart & Run All)
# reproduces the same random vector, sphere and transformation
torch.manual_seed(0)

equi_sphere = EquivariantHyperspheres(n)


P_n = equi_sphere.p # the regular simplex 
p_norm = equi_sphere.p_norm # the scalar p in the paper
M_n = equi_sphere.M_n # the simplex change-of-basis matrix
R_T = equi_sphere.R_T # the rotations from the first simplex vertex p_1 to p_i


# get a random vector and sphere in R^n:
x = torch.randn(n) # random input data vector
c = torch.randn(n) # random sphere center
r = torch.randn(1) # radius (only r**2 enters the embedding, so the sign is irrelevant)
equi_sphere.embed(x, c, r)

X = equi_sphere.X # the embedded input data vector
S = equi_sphere.S # the embedded sphere

# the initial rotation from c/||c|| to the first simplex vertex p_1:
R_O = compute_rotation_from_two_points(c.unsqueeze(0)/c.norm(), equi_sphere.p[:,0].unsqueeze(0))[0] 
R_O_n1n1 = extend_diagonal(R_O, n+1) # the same appended with a one to (n+1)x(n+1)
R_O_n2n2 = extend_diagonal(R_O, n+2) # the same appended with ones to (n+2)x(n+2)

# a random nxn rotation/reflection:
R = random_orthogonal_matrix(n)      
R_n1n1 = extend_diagonal(R, n+1)  # the same appended with a one to (n+1)x(n+1)
R_n2n2 = extend_diagonal(R, n+2)  # the same appended with ones to (n+2)x(n+2)


V_n = equi_sphere.V_n(R_n1n1, R_O_n1n1) # the transformation representation in the output space
B_n = equi_sphere.B_n(R_O_n2n2) # the filter bank containing the spheres forming a regular simplex

# numerical check of the equivariance identity V_n B_n X = B_n R X
# (loose tolerance: everything is computed in float32)
assert torch.allclose(V_n @ B_n @ X, B_n @ R_n2n2 @ X, rtol=1e-3, atol=1e-3), \
    'equivariance check failed: V_n @ B_n @ X != B_n @ R @ X'
In [6]:
# the sphere center c next to its embedding S = (c, 0.5 * (||c||^2 - r^2), 1)
c, S
Out[6]:
(tensor([ 0.8407,  0.8445,  1.8319, -1.5569]),
 tensor([ 0.8407,  0.8445,  1.8319, -1.5569,  2.9811,  1.0000]))
In [7]:
# row i of the filter bank is R_O^T diag(R_T[i], 1, 1) R_O S;
# R_T[0] is the identity, so the first row is the embedded sphere S itself
B_n
Out[7]:
tensor([[ 0.8407,  0.8445,  1.8319, -1.5569,  2.9811,  1.0000],
        [ 1.8062, -1.1950, -1.5476, -0.3377,  2.9811,  1.0000],
        [-1.1962,  1.8023, -1.5512, -0.3375,  2.9811,  1.0000],
        [-1.8539, -1.8563,  0.4923, -0.2726,  2.9811,  1.0000],
        [ 0.4032,  0.4046,  0.7746,  2.5047,  2.9811,  1.0000]])
In [8]:
# the regular simplex as an n x (n+1) matrix: every column is one vertex
P_n
Out[8]:
tensor([[ 0.5000,  0.7135, -0.4045, -0.4045, -0.4045],
        [ 0.5000, -0.4045,  0.7135, -0.4045, -0.4045],
        [ 0.5000, -0.4045, -0.4045,  0.7135, -0.4045],
        [ 0.5000, -0.4045, -0.4045, -0.4045,  0.7135]])
In [9]:
# the norm of a simplex vertex extended with the extra coordinate n^(-1/2)
p_norm  # the scalar p in the paper
Out[9]:
tensor(1.1180)
In [10]:
# the (n+1) x (n+1) change-of-basis matrix: every column is a simplex vertex
# extended with n^(-1/2) and scaled to unit norm
M_n
Out[10]:
tensor([[ 0.4472,  0.6382, -0.3618, -0.3618, -0.3618],
        [ 0.4472, -0.3618,  0.6382, -0.3618, -0.3618],
        [ 0.4472, -0.3618, -0.3618,  0.6382, -0.3618],
        [ 0.4472, -0.3618, -0.3618, -0.3618,  0.6382],
        [ 0.4472,  0.4472,  0.4472,  0.4472,  0.4472]])

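Numerical check of equivariance (Theorem 4)¶

The two cells below evaluate $V_n B_n X$ and $B_n R X$ for the random input, sphere and transformation chosen above; equivariance means that the two outputs coincide.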

In [11]:
# left-hand side: apply the filter bank to X, then transform the outputs with V_n
V_n @ B_n @ X
Out[11]:
tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])
In [12]:
# right-hand side: apply the filter bank to the transformed input R X;
# equivariance holds if this matches V_n @ B_n @ X
B_n @ R_n2n2 @ X
Out[12]:
tensor([-5.2291, -1.4881, -2.3216, -4.2670, -4.2667])

In-between computations and misc¶

In [13]:
# per the output below: p_norm on the diagonal of the top n x n block and a zero last row
M_n @ P_n.T
Out[13]:
tensor([[     1.1180,     -0.0000,     -0.0000,     -0.0000],
        [    -0.0000,      1.1180,     -0.0000,     -0.0000],
        [    -0.0000,     -0.0000,      1.1180,     -0.0000],
        [    -0.0000,     -0.0000,     -0.0000,      1.1180],
        [     0.0000,      0.0000,      0.0000,      0.0000]])
In [14]:
# first factor of V_n @ B_n: the change of basis M_n applied to the filter bank
M_n @ B_n
Out[14]:
tensor([[     2.4863,     -0.5119,     -0.0656,     -1.5973,      0.0000,
              0.0000],
        [    -0.5161,      2.4855,     -0.0692,     -1.5970,      0.0000,
              0.0000],
        [    -1.1737,     -1.1732,      1.9743,     -1.5321,      0.0000,
              0.0000],
        [     1.0833,      1.0878,      2.2566,      1.2451,     -0.0000,
              0.0000],
        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,
              2.2361]])
In [15]:
# next factor of V_n: the transposed initial rotation R_O^T
R_O_n1n1.T @ M_n @ B_n
Out[15]:
tensor([[     2.9999,     -0.0000,     -0.0000,      0.0000,      0.0000,
              0.0000],
        [    -0.0000,      2.9999,     -0.0000,      0.0000,     -0.0000,
             -0.0000],
        [    -0.0000,     -0.0000,      2.9999,      0.0000,     -0.0000,
             -0.0000],
        [     0.0000,      0.0000,      0.0000,      2.9999,     -0.0000,
             -0.0000],
        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,
              2.2361]])
In [16]:
# next factor of V_n: the random rotation/reflection R
R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[16]:
tensor([[    -1.5642,      1.5692,      1.0586,     -1.7234,     -0.0000,
              0.0000],
        [    -0.6657,     -2.3917,      1.5702,     -0.6090,     -0.0000,
              0.0000],
        [    -1.8822,     -0.8062,     -2.1636,     -0.3547,      0.0000,
              0.0000],
        [    -1.6022,      0.4089,      0.8558,      2.3522,     -0.0000,
             -0.0000],
        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,
              2.2361]])
In [17]:
# next factor of V_n: the initial rotation R_O
R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[17]:
tensor([[    -0.2886,      1.5085,      0.2010,     -2.5691,     -0.0000,
              0.0000],
        [     0.6139,     -2.4506,      0.7131,     -1.4521,      0.0000,
              0.0000],
        [     0.4519,     -0.4181,     -2.8892,     -0.5223,      0.0000,
              0.0000],
        [    -2.8870,     -0.7373,     -0.3207,     -0.1337,     -0.0000,
             -0.0000],
        [     0.0000,     -0.0000,      0.0000,     -0.0000,      6.6659,
              2.2361]])
In [18]:
# last factor of V_n: M_n^T; the full product is V_n @ B_n written out factor by factor
M_n.T @ R_O_n1n1 @ R_n1n1 @ R_O_n1n1.T @ M_n @ B_n
Out[18]:
tensor([[-0.9436, -0.9380, -1.0267, -2.0917,  2.9811,  1.0000],
        [ 0.4748,  2.2674,  1.0316, -0.8768,  2.9811,  1.0000],
        [ 1.3772, -1.6917,  1.5437,  0.2401,  2.9811,  1.0000],
        [ 1.2152,  0.3408, -2.0586,  1.1699,  2.9811,  1.0000],
        [-2.1237,  0.0216,  0.5099,  1.5585,  2.9811,  1.0000]])
In [19]:
# M_n applied to the all-ones vector: per the output below, only the last coordinate is non-zero
M_n @ torch.ones(n+1)
Out[19]:
tensor([0.0000, 0.0000, 0.0000, 0.0000, 2.2361])

Numeric instances¶

In [20]:
# Closed-form instances of the simplex P_n, the scalar p_norm_n and the
# change-of-basis matrix M_n. For every n, P_n and the top rows of M_n share
# the same unscaled rows; M_n appends a row of ones and uses its own scale.

# n=2:
_sqrt3 = 3 ** 0.5
_rows_2 = [[1, (_sqrt3 - 1) / 2, -(_sqrt3 + 1) / 2],
           [1, -(_sqrt3 + 1) / 2, (_sqrt3 - 1) / 2]]

P_2 = 2**(-0.5) * torch.tensor(_rows_2)

p_norm_2 = 1.5**0.5

M_2 = 3**(-0.5) * torch.tensor(_rows_2 + [[1, 1, 1]])


# n=3:
_inv_sqrt3 = 1 / (3 ** 0.5)
_signs_3 = [[1, 1, -1, -1],
            [1, -1, 1, -1],
            [1, -1, -1, 1]]

P_3 = torch.tensor([[sign * _inv_sqrt3 for sign in row] for row in _signs_3])

p_norm_3 = 2 * 3**(-0.5)

M_3 = 0.5 * torch.tensor(_signs_3 + [[1, 1, 1, 1]])


# n=4:
_sqrt5 = 5 ** 0.5
_on_diag = (3 * _sqrt5 - 1) / 4
_off_diag = -(_sqrt5 + 1) / 4
_rows_4 = [[1] + [_on_diag if col == row else _off_diag for col in range(4)]
           for row in range(4)]

P_4 = 1/2 * torch.tensor(_rows_4)

p_norm_4 = _sqrt5 / 2

M_4 = 1/_sqrt5 * torch.tensor(_rows_4 + [[1, 1, 1, 1, 1]])