%load_ext autoreload
%autoreload 2
%matplotlib inline
SEED = 239
import numpy as np
np.random.seed(SEED)
np.set_printoptions(precision=4)
import torch
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.collections as mcol
import matplotlib.transforms as mtransforms
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation as R
from scipy.special import softmax
from tqdm import tqdm
from models import PointCMLP, SteerableModel
from utils import (get_tetris_data, plot_shapes, score, build_mlgp,
construct_filter_banks, unembed_points,
transform_parameters, build_steerable_model,
random_axis_angle, torch_rotation_matrix, entropy)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.__version__, device
('1.8.1+cu102', device(type='cuda'))
We build the steerable spherical classifier in the following steps:
Step 1. Train the ancestor MLGP.
Step 2. Transform the hidden unit parameters into filter banks.
Step 3. Fix the learned parameters and add the interpolation coefficients $v^k$ as learnable parameters to fulfill the steerability constraint $\rightarrow$ Steerable spherical classifier.
Xtrain, Ytrain, shape_names = get_tetris_data()
for shape, name in zip(Xtrain, shape_names):
    print(name, '\n', shape, '\n\n')
chiral_shape_1
 tensor([[0., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 1., 0.]])

chiral_shape_2
 tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  1.],
        [ 1.,  0.,  0.],
        [ 1., -1.,  0.]])

square
 tensor([[0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])

line
 tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 2.],
        [0., 0., 3.]])

corner
 tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.]])

L
 tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 2.],
        [0., 1., 0.]])

T
 tensor([[0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 2.],
        [0., 1., 1.]])

zigzag
 tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.],
        [2., 1., 0.]])
plot_shapes(Xtrain, shape_names, offset=1.5)
N_GEOMETRIC_NEURONS = 5
OUTPUT_DIM = len(set(Ytrain.numpy()))
# set the seed here:
torch.manual_seed(SEED)
# instantiate the model:
model = build_mlgp(input_shape=Xtrain.shape[1:], output_dim=OUTPUT_DIM, hidden_layer_sizes=[N_GEOMETRIC_NEURONS], bias=False)
print(model)
print('total number of trainable parameters:', sum([np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())]))
print()
model = model.to(device)
Xtrain, Ytrain = Xtrain.float().to(device), Ytrain.to(device)
# define the loss and optimizer:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 2000
# train the model:
for i in range(epochs):
    y_pred = model(Xtrain)
    loss = criterion(y_pred, Ytrain)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = score(y_pred.detach(), Ytrain)
    if i % 500 == 0:
        print('epoch: %d, loss: %.3f, acc: %.3f' % (i, loss.item(), acc))
print('epoch: %d, loss: %.3f, acc: %.3f' % (i, loss.item(), acc))
model: MLGP
PointCMLP(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=20, out_features=5, bias=False)
  )
  (out_layer): Linear(in_features=7, out_features=8, bias=False)
)
total number of trainable parameters: 156

epoch: 0, loss: 2.064, acc: 0.375
epoch: 500, loss: 0.318, acc: 1.000
epoch: 1000, loss: 0.054, acc: 1.000
epoch: 1500, loss: 0.019, acc: 1.000
epoch: 1999, loss: 0.009, acc: 1.000
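As an aside, the score helper used above presumably computes plain argmax classification accuracy; a minimal sketch under that assumption (not the actual utils implementation):

# hypothetical sketch of the score helper (assumed argmax accuracy):
def score_sketch(logits, labels):
    preds = torch.argmax(logits, dim=1)              # predicted class per shape
    return (preds == labels).float().mean().item()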
The main function in this step is construct_filter_banks.
To demonstrate how it works, we first extract the ancestor model hidden layer (i.e., geometric neurons) spheres, and then take one of them to form a filter bank:
# extract the spheres from the ancestor model:
original_state_dict = model.state_dict()
# get the geometric neuron spheres:
hidden_name = 'hidden_layers.0.weight'
hidden_spheres = original_state_dict[hidden_name] # (n_geometric_neurons x N_points*5)
# each sphere is a parameter vector of length 5;
# each geometric neuron contains a number of spheres corresponding to the number of input points
# in the point set
# reshape to (n_geometric_neurons x N_points x 5):
hidden_spheres_numpy = hidden_spheres.detach().cpu().numpy().reshape(len(hidden_spheres), -1, 5)
print('hidden_spheres_numpy.shape:', hidden_spheres_numpy.shape)
# e.g., select the third sphere from the second geometric neuron:
one_sphere = hidden_spheres_numpy[1,2,:]
print('\nS_tilde_k = ', one_sphere)
# construct a filter bank for this sphere:
init_rotation, filter_bank = construct_filter_banks(one_sphere, return_init_rotations=True)
print('\nR_O^k =\n', init_rotation, '\n\nB(S_tilde_k) =\n', filter_bank)
hidden_spheres_numpy.shape: (5, 4, 5)

S_tilde_k =  [-0.6873  0.6743 -0.2568 -0.4767 -0.417 ]

R_O^k =
 [[[ 0.408  -0.6725 -0.6174]
   [ 0.9052  0.2101  0.3693]
   [-0.1187 -0.7096  0.6945]]]

B(S_tilde_k) =
 [[-0.6873  0.6743 -0.2568 -0.4767 -0.417 ]
  [ 0.2178  0.0995  0.9673 -0.4767 -0.417 ]
  [-0.3543 -0.9161 -0.1682 -0.4767 -0.417 ]
  [ 0.8238  0.1422 -0.5423 -0.4767 -0.417 ]]
We normalize, i.e., unembed, the resulting spheres to obtain their Euclidean $\mathbb{R}^3$ representation. Note in the printout above that the four bank spheres share their last two components: the filter-bank construction rotates only the first three. Unembedding a 5-vector recovers the sphere center, given by its (rescaled) first three elements:
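Judging by the numbers printed below, this unembedding amounts to a projective rescaling: dividing the first three components of each 5-vector by its fifth (scale) component recovers the center. A minimal sketch under that assumption (not necessarily the exact utils.unembed_points):

# illustrative sketch of unembedding (assumed projective rescaling):
def unembed_sketch(spheres):
    spheres = np.atleast_2d(spheres)
    return spheres[:, :3] / spheres[:, 4:5]   # centers in R^3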
# the centers of the filter bank spheres:
centers = unembed_points(filter_bank)
print('\nthe four sphere centers:\n', centers)
fig = plt.figure(1, figsize=(7,7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=100)
plt.title('A regular tetrahedron formed by the centers of the $B(S_k)$ spheres')
plt.show()
the four sphere centers:
 [[ 1.6482 -1.6171  0.6159]
  [-0.5223 -0.2387 -2.3197]
  [ 0.8498  2.1969  0.4033]
  [-1.9756 -0.3411  1.3006]]
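The plot title claims a regular tetrahedron; an optional check is that all six pairwise distances between the four centers coincide:

# optional sanity check: the six edge lengths should be (approximately) equal
from itertools import combinations
edges = [np.linalg.norm(centers[i] - centers[j]) for i, j in combinations(range(4), 2)]
print(np.round(edges, 4))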
Step 2 is wrapped into the function transform_parameters, which:
- takes in the trained ancestor MLGP model;
- transforms its parameters --- it uses the hidden layer spheres to create the filter banks and keeps the output (classification) layer the same;
- returns the initial rotations $R_O^k$, the filter banks $B(\tilde{S}_k)$, and the ancestor model output layer parameters.
transformed_parameters = transform_parameters(model) # used in the experiments further down
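For concreteness, here is a rough sketch of what this wrapper plausibly does, assembled from the pieces demonstrated above; the actual implementation in utils may differ:

# hypothetical sketch of Step 2 (not the actual utils.transform_parameters):
def transform_parameters_sketch(mlgp):
    state = mlgp.state_dict()
    spheres = state['hidden_layers.0.weight'].detach().cpu().numpy()
    spheres = spheres.reshape(len(spheres), -1, 5)   # (neurons, points, 5)
    init_rotations, filter_banks = [], []
    for neuron_spheres in spheres:                   # each geometric neuron
        for sphere in neuron_spheres:                # each of its spheres
            R_O, bank = construct_filter_banks(sphere, return_init_rotations=True)
            init_rotations.append(R_O)
            filter_banks.append(bank)
    out_layer = state['out_layer.weight'].detach().cpu().numpy()  # kept as-is
    return init_rotations, filter_banks, out_layer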
Step 3 is wrapped into the build_steerable_model function, which:
- creates a steerable model with learnable interpolation coefficients $v^k$ according to the constraint (13) in the paper;
- sets the rest of the model parameters to those obtained in Step 2 (the filter banks and the unchanged output layer) and keeps them fixed (not updated).
# choose initial model parameters:
init_axis_angle = random_axis_angle()
# use the initial parameters and the transformed ancestor model parameters obtained in Step 2
# to build a steerable spherical classifier:
steerable_model = build_steerable_model(input_shape=Xtrain.shape[1:],
output_dim=OUTPUT_DIM,
hidden_layer_sizes=[N_GEOMETRIC_NEURONS],
init_axis_angle=init_axis_angle,
transformed_parameters=transformed_parameters,
print_hidden_layer_output=False).to(device)
print(steerable_model)
print('total number of trainable parameters:', \
sum([np.prod(p.size()) for p in filter(lambda p: p.requires_grad, steerable_model.parameters())]))
print('\ninit_axis_angle:\n', init_axis_angle)
print('\nsteerable_model.axis_angle:\n', steerable_model.axis_angle)
SteerableModel(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=80, out_features=5, bias=False)
  )
  (out_layer): Linear(in_features=7, out_features=8, bias=False)
)
total number of trainable parameters: 3

init_axis_angle:
 [-0.9549 -0.4905  1.5045]

steerable_model.axis_angle:
 Parameter containing:
tensor([-0.9549, -0.4905,  1.5045], device='cuda:0', dtype=torch.float64,
       requires_grad=True)
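As a quick sanity check, only the three axis-angle components should actually be trainable, with everything inherited from Step 2 frozen:

# list which parameters receive gradients (only axis_angle should):
for name, p in steerable_model.named_parameters():
    print(name, tuple(p.shape), 'trainable' if p.requires_grad else 'frozen')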
Finally, we record the ancestor model's hidden-layer and output activations on the canonical (unrotated) training shapes via a forward hook; these ground-truth activations are used in the experiments further down:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
model.hidden_layers[0].register_forward_hook(get_activation('hidden_layer'))
output = model(Xtrain)
# the ground truth hidden activations -- the ancestor MLGP hidden layer output:
gt_hidden_activations = activation['hidden_layer'].detach().cpu().numpy()
# the ground truth output activations -- the ancestor MLGP model output:
gt_outs = output.detach().cpu().numpy()
# gt_hidden_activations, gt_outs
np.random.seed(SEED)
torch.manual_seed(SEED)
n_trials = 1000
# the parameter of additive uniform noise to apply to the rotated shapes:
distortions = [0.0, 0.05, 0.1, 0.2, 0.3, 0.5]
init_axis_angles = []
accs = dict() # classification accuracies for the perturbed rotated shapes
dists = dict() # L1 distances to the ground truth hidden activations
# the same for the ancestor:
ancestor_accs = dict()
ancestor_dists = dict()
for distortion in distortions:
    accs[distortion] = []
    dists[distortion] = []
    ancestor_accs[distortion] = []
    ancestor_dists[distortion] = []
    print('\ndistortion:', distortion)
    for n in range(n_trials):
        # construct a random ground truth rotation:
        init_axis_angle = random_axis_angle()
        init_axis_angles.append(init_axis_angle)
        gt_rotation = torch_rotation_matrix(init_axis_angle).to(device).float()
        # rotate the shapes with the ground truth:
        test_data = Xtrain.reshape(-1, 3) @ gt_rotation.T
        test_data = test_data.reshape(Xtrain.shape)
        test_label = Ytrain
        # add uniform noise to the transformed shapes:
        noise = distortion * (2 * torch.rand(test_data.shape).to(device) - 1)
        test_data += noise
        # construct the steerable model with the initial axis-angle parameters:
        steerable_model = build_steerable_model(input_shape=test_data.shape[1:],
                                                output_dim=OUTPUT_DIM,
                                                hidden_layer_sizes=[N_GEOMETRIC_NEURONS],
                                                init_axis_angle=init_axis_angle,
                                                transformed_parameters=transformed_parameters,
                                                print_hidden_layer_output=False).to(device)
        # get the model output:
        output = steerable_model(test_data)
        ancestor_output = model(test_data)
        # compute the model accuracy for the perturbed rotated shapes:
        acc = score(output.detach(), test_label)
        accs[distortion].append(acc)
        ancestor_acc = score(ancestor_output.detach(), test_label)
        ancestor_accs[distortion].append(ancestor_acc)
        # compute the L1 distance between the hidden activations:
        hidden_activations = steerable_model.hidden_layer_activations.cpu().numpy()
        dist = np.linalg.norm(hidden_activations - gt_hidden_activations, ord=1, axis=1)
        dist = np.mean(dist)
        dists[distortion].append(dist)
        ancestor_hidden_activations = model.hidden_layer_activations.cpu().numpy()
        ancestor_dist = np.linalg.norm(ancestor_hidden_activations - gt_hidden_activations, ord=1, axis=1)
        ancestor_dist = np.mean(ancestor_dist)
        ancestor_dists[distortion].append(ancestor_dist)
        # if n % 10 == 0:
        #     print('\nexperiment #%d/%d' % (n+1, n_trials))
        #     print('\nadditive_uniform_noise:\n', noise)
        #     print('\ngt_rotation:\n', gt_rotation)
        #     print('\nacc: %.3f' % acc)
    print()
    print('ancestor_acc: %.4f +/- %.4f' % (np.mean(ancestor_accs[distortion]), np.std(ancestor_accs[distortion])))
    print('acc: %.4f +/- %.4f' % (np.mean(accs[distortion]), np.std(accs[distortion])))
    print()
    print('ancestor L1 dist: %.4f +/- %.4f' % (np.mean(ancestor_dists[distortion]), np.std(ancestor_dists[distortion])))
    print('L1 dist: %.4f +/- %.4f' % (np.mean(dists[distortion]), np.std(dists[distortion])))
    print(end='\n\n')
distortion: 0.0

ancestor_acc: 0.4728 +/- 0.3403
acc: 1.0000 +/- 0.0000

ancestor L1 dist: 8.1021 +/- 4.1254
L1 dist: 0.0000 +/- 0.0000


distortion: 0.05

ancestor_acc: 0.4761 +/- 0.3424
acc: 1.0000 +/- 0.0000

ancestor L1 dist: 8.0639 +/- 4.0796
L1 dist: 0.3319 +/- 0.0472


distortion: 0.1

ancestor_acc: 0.4928 +/- 0.3501
acc: 1.0000 +/- 0.0000

ancestor L1 dist: 7.9400 +/- 3.9921
L1 dist: 0.6625 +/- 0.0953


distortion: 0.2

ancestor_acc: 0.4595 +/- 0.3439
acc: 0.9999 +/- 0.0040

ancestor L1 dist: 8.3054 +/- 3.8510
L1 dist: 1.3219 +/- 0.1936


distortion: 0.3

ancestor_acc: 0.4756 +/- 0.3439
acc: 0.9971 +/- 0.0187

ancestor L1 dist: 8.2549 +/- 3.7809
L1 dist: 2.0030 +/- 0.3098


distortion: 0.5

ancestor_acc: 0.4445 +/- 0.3193
acc: 0.9491 +/- 0.0772

ancestor L1 dist: 8.6461 +/- 3.4424
L1 dist: 3.3332 +/- 0.4837
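The next experiment refines a distorted initial rotation at test time by minimizing the entropy of the classifier output with respect to the axis-angle parameters. A minimal sketch of such an entropy-from-logits loss, assuming the entropy helper computes the mean Shannon entropy (the actual utils implementation may differ):

# hypothetical sketch of an entropy loss computed from logits:
def entropy_from_logits(logits):
    log_p = F.log_softmax(logits, dim=1)   # numerically stable log-probabilities
    p = log_p.exp()
    return -(p * log_p).sum(dim=1).mean()  # mean Shannon entropy over the batch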
np.random.seed(SEED)
torch.manual_seed(SEED)
n_trials = 1000
epochs = 300
# print_period = 50 # for online optimization
# distortion rotation angle parameters (in degrees):
distortion_std = 10
distortion_means = [0, 5, 10, 15, 30]
all_results = dict()
for distortion_mean in distortion_means:
    print('\n\n\ndistortion = %d +/- %d degrees' % (distortion_mean, distortion_std))
    gt_axis_angles = []
    init_axis_angles = []
    gt_labels = []
    init_outputs = []
    init_predictions = []
    init_losses = []
    init_accs = []
    init_dists = []
    final_outputs = []
    final_predictions = []
    final_losses = []
    final_accs = []
    final_dists = []
    optimized_axis_angles = []
    for n in tqdm(range(n_trials)):
        # 1) Randomly transform the train data:
        gt_axis_angle = random_axis_angle()
        gt_axis_angles.append(gt_axis_angle)
        gt_rotation = torch_rotation_matrix(gt_axis_angle).to(device).float()
        # print('\ngt_rotation:\n', gt_rotation)
        # select a sample:
        shape_label = np.random.choice(len(Xtrain))
        # print('\nshape_label:', shape_label)
        gt_labels.append(shape_label)
        test_data = Xtrain.reshape(-1, 3) @ gt_rotation.T
        test_data = test_data.reshape(Xtrain.shape)[shape_label:shape_label+1]
        test_label = Ytrain[shape_label:shape_label+1]
        # 2) Initialize the steerable model parameters with the distorted ground-truth axis-angle.
        # create a "distortion" axis-angle:
        distortion_angle = np.radians(distortion_std)*np.random.randn() + np.radians(distortion_mean)
        # print('\ndistortion_angle:', np.degrees(distortion_angle))
        distortion_axis_angle = random_axis_angle(angle=distortion_angle)
        distortion_matrix = torch_rotation_matrix(distortion_axis_angle)
        # by multiplying distortion_matrix with the gt_rotation matrix,
        # we can control the rotation angle randomness:
        distorted_rotation = distortion_matrix @ gt_rotation.cpu()
        init_r = R.from_matrix(distorted_rotation.cpu().numpy())
        init_axis_angle = init_r.as_rotvec()
        init_axis_angles.append(init_axis_angle)
        # construct the steerable model with init_axis_angle:
        steerable_model = build_steerable_model(input_shape=test_data.shape[1:],
                                                output_dim=OUTPUT_DIM,
                                                hidden_layer_sizes=[N_GEOMETRIC_NEURONS],
                                                init_axis_angle=init_axis_angle,
                                                transformed_parameters=transformed_parameters,
                                                print_hidden_layer_output=False).to(device)
        # get initial model output, hidden_activations, acc and loss:
        output = steerable_model(test_data)
        init_output = output.detach().cpu().numpy()
        init_outputs.append(init_output)
        init_dist = np.linalg.norm(softmax(init_output) - softmax(gt_outs[shape_label]), ord=1)
        init_dists.append(init_dist)
        init_acc = score(output.detach(), test_label)
        init_accs.append(init_acc)
        init_loss = entropy(output, is_logits=True).item()
        init_losses.append(init_loss)
        # print('\ninit_acc: %.3f' % init_acc)
        # collect the initial predictions:
        init_prediction = torch.argmax(output.detach(), axis=1)
        init_predictions.append(init_prediction.cpu().numpy())
        # print('initial prediction:', init_prediction.cpu().numpy())
        # print()
        # 3) Optimize the entropy loss wrt the axis-angle parameters of the steerable model:
        optimizer = optim.Adam(steerable_model.parameters(), lr=1e-2)
        for i in range(epochs):
            # compute the entropy loss:
            output = steerable_model(test_data)
            loss = entropy(output, is_logits=True)
            # backpropagate:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # if i % print_period == 0:
            #     print('epoch: %d, loss: %.3f' % (i, loss.item()))
        # store the final data:
        final_losses.append(loss.item())
        final_output = output.detach().cpu().numpy()
        final_outputs.append(final_output)
        final_dist = np.linalg.norm(softmax(final_output) - softmax(gt_outs[shape_label]), ord=1)
        final_dists.append(final_dist)
        final_acc = score(output.detach(), test_label)
        final_accs.append(final_acc)
        final_prediction = torch.argmax(output.detach(), axis=1)
        final_predictions.append(final_prediction.cpu().numpy())
        optimized_axis_angle = steerable_model.axis_angle.detach().cpu().numpy()
        optimized_axis_angles.append(optimized_axis_angle)
        # print('epoch: %d, loss: %.3f' % (i, loss.item()))
        # print('\nfinal_acc: %.3f' % final_acc)
        # print('final prediction:', final_prediction.cpu().numpy())
        # print('\ngt_axis_angle:\n', gt_axis_angle)
        # print('\ninit_axis_angle:\n', init_axis_angle)
        # print('optimized_axis_angle:\n', optimized_axis_angle)
    all_results[distortion_mean] = {
        'gt_axis_angles': gt_axis_angles,
        'gt_labels': gt_labels,
        'init_axis_angles': init_axis_angles,
        'init_outputs': init_outputs,
        'init_predictions': init_predictions,
        'init_losses': init_losses,
        'init_accs': init_accs,
        'init_dists': init_dists,
        'final_outputs': final_outputs,
        'final_predictions': final_predictions,
        'final_losses': final_losses,
        'final_accs': final_accs,
        'final_dists': final_dists,
        'optimized_axis_angles': optimized_axis_angles,
    }
    print('\nmean_init_acc: %.4f, mean_final_acc: %.4f'
          % (np.mean(all_results[distortion_mean]['init_accs']), np.mean(all_results[distortion_mean]['final_accs'])))
    print('\nmean_init_dist: %.4f +/- %.4f, mean_final_dist: %.4f +/- %.4f'
          % (np.mean(all_results[distortion_mean]['init_dists']), np.std(all_results[distortion_mean]['init_dists']),
             np.mean(all_results[distortion_mean]['final_dists']), np.std(all_results[distortion_mean]['final_dists'])))
distortion = 0 +/- 10 degrees
100%|██████████| 1000/1000 [51:10<00:00, 3.07s/it]

mean_init_acc: 0.9990, mean_final_acc: 0.9990
mean_init_dist: 0.0072 +/- 0.0328, mean_final_dist: 0.0065 +/- 0.0313


distortion = 5 +/- 10 degrees
100%|██████████| 1000/1000 [1:04:02<00:00, 3.84s/it]

mean_init_acc: 0.9980, mean_final_acc: 0.9990
mean_init_dist: 0.0092 +/- 0.0376, mean_final_dist: 0.0066 +/- 0.0315


distortion = 10 +/- 10 degrees
100%|██████████| 1000/1000 [53:55<00:00, 3.24s/it]

mean_init_acc: 0.9980, mean_final_acc: 0.9980
mean_init_dist: 0.0133 +/- 0.0459, mean_final_dist: 0.0074 +/- 0.0441


distortion = 15 +/- 10 degrees
100%|██████████| 1000/1000 [59:23<00:00, 3.56s/it]

mean_init_acc: 0.9920, mean_final_acc: 0.9930
mean_init_dist: 0.0250 +/- 0.0808, mean_final_dist: 0.0123 +/- 0.0822


distortion = 30 +/- 10 degrees
100%|██████████| 1000/1000 [54:45<00:00, 3.29s/it]

mean_init_acc: 0.9360, mean_final_acc: 0.9410
mean_init_dist: 0.1059 +/- 0.2049, mean_final_dist: 0.0635 +/- 0.2323
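A compact way to read the results above is to compare the initial and optimized metrics side by side; the test-time entropy minimization consistently reduces the L1 distance to the ground-truth class distribution:

# summarize initial vs. optimized results per distortion level:
for dm in distortion_means:
    r = all_results[dm]
    print('%2d deg: acc %.4f -> %.4f, L1 dist %.4f -> %.4f'
          % (dm, np.mean(r['init_accs']), np.mean(r['final_accs']),
             np.mean(r['init_dists']), np.mean(r['final_dists'])))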