In this notebook, we show how a topological loss can be combined with a linear embedding procedure, so as to regularize the embedding and better reflect a topological prior, in this case a circular one.
We start by setting the working directory and importing the necessary libraries.
# Set working directory
import os
os.chdir("..")
# Handling arrays and data.frames
import pandas as pd
import numpy as np
# Loading R objects into python
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
# Pytorch compatible topology layer and losses
import torch
from topologylayer.nn import AlphaLayer
from Code.losses import DiagramLoss, pca_loss, ortho_loss
# Random sampling for topological loss
import random
# Ordinary and topologically regularized PCA embedding
from Code.topembed import PCA
# Plotting
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm
# Quantitative evaluation
from sklearn.svm import SVC
from Code.evaluation import evaluate_embeddings
# Representative cycle analysis with alpha-filtrations
import networkx as nx
import diode
import dionysus
# Tracking computation times
import time
%matplotlib inline
Next, we load the data and visualize it by means of its ordinary PCA embedding.
# Load the data
file_name = os.path.join("Data", "CellCycle.rds")
cell_info = ro.r["readRDS"](file_name)
cell_info = dict(zip(cell_info.names, list(cell_info)))
pandas2ri.activate()
data = ro.conversion.rpy2py(cell_info["expression"])
t = list(ro.conversion.rpy2py(cell_info["cell_info"])
.rename(columns={"milestone_id": "group_id"}).loc[:,"group_id"])
pandas2ri.deactivate()
print("Data shape: " + str(data.shape))
# Conduct ordinary PCA embedding
Y_pca, W_pca = PCA(data, random_state=42)
# View the data through its PCA embedding
fig, ax = plt.subplots()
sns.scatterplot(x=Y_pca[:,0], y=Y_pca[:,1], s=50, hue=t, palette="husl")
ax.get_legend().remove()
plt.show()
Data shape: (264, 6812)
Time for embedding: 00:00:00
We now show how we can bias a linear embedding using a loss function that captures our topological prior.
The model we will use for this learns a linear projection $W$, which is optimized with respect to the following three losses:
- an embedding (reconstruction) loss, measuring how well the projection preserves the original data (pca_loss);
- an orthonormality loss, penalizing deviations of $W$ from having orthonormal columns (ortho_loss);
- a topological loss, encoding our topological prior on the embedding.
As a topological loss, we will use the persistence of the most prominent cycle in our embedding. It is important to multiply this by a factor $\lambda_{\mathrm{top}} < 0$, since we want this persistence to be high. To obtain this loss, we require an additional layer that constructs the alpha complex from the embedding, from which persistent homology is subsequently computed.
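Concretely, writing $(b_1, d_1)$ for the birth-death pair of the most persistent one-dimensional feature (cycle) in the alpha-complex persistence diagram of an embedding $Y$, the topological loss defined below is

$$
\mathcal{L}_{\mathrm{top}}(Y) = \lambda_{\mathrm{top}}\,(d_1 - b_1),
$$

and the total objective reported in the training logs is the sum of the reconstruction, orthonormality, and topological terms. The exact reconstruction and orthonormality terms are the ones implemented in pca_loss and ortho_loss from Code.losses; judging from the near-zero orthonormality loss of the ordinary PCA projection, the latter presumably penalizes deviations of $W^{\top} W$ from the identity.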
# Define topological loss
def g(p): return p[1] - p[0] # function that returns the persistence d - b of a point (b, d)
TopLayer = AlphaLayer(maxdim=1) # alpha complex layer
CircularPersistence = DiagramLoss(dim=1, j=1, g=g) # compute persistence of most prominent cycle
lambda_top = -1e2 # scalar factor that trades off embedding and topological loss
top_frac = 0.25 # sample fraction for which the topological loss is computed
# Construct topological loss function
def top_loss(output):
    if top_frac < 1:
        sample = random.sample(range(output.shape[0]), int(output.shape[0] * top_frac))
        output = output[sample,:]
    dgminfo = TopLayer(output)
    loss = lambda_top * CircularPersistence(dgminfo)
    return loss
We can now conduct the topologically regularized linear embedding as follows.
# Learning hyperparameters
num_epochs = 1000
learning_rate = 5e-4
# Conduct topological regularization
Y_top, W_top, losses_top = PCA(data, top_loss=top_loss, num_epochs=num_epochs,
learning_rate=learning_rate, random_state=42)
# View topologically regularized embedding
fig, ax = plt.subplots()
sns.scatterplot(x=Y_top[:,0], y=Y_top[:,1], s=50, hue=t, palette="husl")
ax.get_legend().remove()
plt.show()
[epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1029.116089, total loss: -1022.410156]
[epoch 100] [emb. loss: 6.794199, ortho. loss: 39.050903, top. loss: -2206.719482, total loss: -2160.874268]
[epoch 200] [emb. loss: 6.871301, ortho. loss: 28.582264, top. loss: -3853.351318, total loss: -3817.897705]
[epoch 300] [emb. loss: 6.911356, ortho. loss: 15.225312, top. loss: -4771.732422, total loss: -4749.595703]
[epoch 400] [emb. loss: 6.944156, ortho. loss: 78.957901, top. loss: -5473.927246, total loss: -5388.025391]
[epoch 500] [emb. loss: 6.963151, ortho. loss: 68.097542, top. loss: -3985.763672, total loss: -3910.702881]
[epoch 600] [emb. loss: 6.969546, ortho. loss: 54.162918, top. loss: -5400.137207, total loss: -5339.004883]
[epoch 700] [emb. loss: 6.980327, ortho. loss: 28.453640, top. loss: -5472.357422, total loss: -5436.923340]
[epoch 800] [emb. loss: 6.993236, ortho. loss: 55.716778, top. loss: -5505.524414, total loss: -5442.814453]
[epoch 900] [emb. loss: 6.992912, ortho. loss: 29.378899, top. loss: -5253.819824, total loss: -5217.448242]
[epoch 1000] [emb. loss: 6.990539, ortho. loss: 6.560929, top. loss: -4963.142578, total loss: -4949.591309]
Time for embedding: 00:00:34
We observe that we can regularize our linear embedding through the topological prior, obtaining a much more prominent cycle, while maintaining a nearly identical reconstruction error.
For comparison, we also conduct the same topological optimization procedure directly on the initialized embedding.
# Learning hyperparameters
num_epochs = 1000
learning_rate = 5e-4
# Conduct topological optimization
Y_opt, W_opt, losses_opt = PCA(data, emb_loss=False, top_loss=top_loss, num_epochs=num_epochs,
learning_rate=learning_rate, random_state=42)
# View topologically optimized embedding
fig, ax = plt.subplots()
sns.scatterplot(x=Y_opt[:,0], y=Y_opt[:,1], s=50, hue=t, palette="husl")
ax.get_legend().remove()
plt.show()
[epoch 1] [emb. loss: 0.000000, ortho. loss: 0.001194, top. loss: -1029.116089, total loss: -1029.114868]
[epoch 100] [emb. loss: 0.000000, ortho. loss: 60.290764, top. loss: -2210.136719, total loss: -2149.845947]
[epoch 200] [emb. loss: 0.000000, ortho. loss: 34.799610, top. loss: -2403.161133, total loss: -2368.361572]
[epoch 300] [emb. loss: 0.000000, ortho. loss: 43.243652, top. loss: -5054.873047, total loss: -5011.629395]
[epoch 400] [emb. loss: 0.000000, ortho. loss: 49.574078, top. loss: -5114.975098, total loss: -5065.400879]
[epoch 500] [emb. loss: 0.000000, ortho. loss: 46.421947, top. loss: -3617.493896, total loss: -3571.072021]
[epoch 600] [emb. loss: 0.000000, ortho. loss: 101.706757, top. loss: -5488.114746, total loss: -5386.408203]
[epoch 700] [emb. loss: 0.000000, ortho. loss: 33.808533, top. loss: -5564.112793, total loss: -5530.304199]
[epoch 800] [emb. loss: 0.000000, ortho. loss: 51.712521, top. loss: -5432.167969, total loss: -5380.455566]
[epoch 900] [emb. loss: 0.000000, ortho. loss: 58.593304, top. loss: -5877.037109, total loss: -5818.443848]
[epoch 1000] [emb. loss: 0.000000, ortho. loss: 51.424702, top. loss: -5112.462402, total loss: -5061.037598]
Time for embedding: 00:00:18
We observe that the results are highly similar.
We now evaluate the embeddings quantitatively. First, we evaluate the different losses (reconstruction, orthonormality, and topological) for all final embeddings.
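Because the topological loss is computed on a random subsample of the embedded points, we report its expected value, approximated by averaging over n_samples random draws and rescaling by $\lvert\lambda_{\mathrm{top}}\rvert$:

$$
\widehat{\mathcal{L}}_{\mathrm{top}} = \frac{1}{n_{\mathrm{samples}}\,\lvert\lambda_{\mathrm{top}}\rvert} \sum_{k=1}^{n_{\mathrm{samples}}} \mathcal{L}_{\mathrm{top}}\bigl(Y_{S_k}\bigr),
$$

where $S_k$ is the $k$-th random subsample. Since $\lambda_{\mathrm{top}} < 0$, more negative values indicate a more prominent cycle.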
n_samples = 250 # number of samples for approximating (expected value of) topological loss
random.seed(42)
print("\033[1mLosses for pca embedding: \033[0m")
print("Reconstruction: " + str(losses_top["embedding"][0])) # PCA initialization gives first embedding loss
print("Orthonormality: " + str(ortho_loss(torch.tensor(W_pca)).item()))
print("Topological: " + str(np.mean([top_loss(torch.tensor(Y_pca).type(torch.float)) for _ in range(n_samples)])
/ np.abs(lambda_top)) + "\n")
print("\033[1mLosses for topologically optimized pca embedding: \033[0m")
print("Reconstruction: " + str(pca_loss(torch.tensor(data - data.mean(axis=0)),
torch.tensor(W_opt), torch.tensor(Y_opt)).item()))
print("Orthonormality: " + str(ortho_loss(torch.tensor(W_opt)).item()))
print("Topological: " + str(np.mean([top_loss(torch.tensor(Y_opt).type(torch.float)) for _ in range(n_samples)])
/ np.abs(lambda_top)) + "\n")
print("\033[1mLosses for topologically regularized pca embedding: \033[0m")
print("Reconstruction: " + str(pca_loss(torch.tensor(data - data.mean(axis=0)),
torch.tensor(W_top), torch.tensor(Y_top)).item()))
print("Orthonormality: " + str(ortho_loss(torch.tensor(W_top)).item()))
print("Topological: " + str(np.mean([top_loss(torch.tensor(Y_top).type(torch.float)) for _ in range(n_samples)])
/ np.abs(lambda_top)))
Losses for pca embedding:
Reconstruction: 6.7047271728515625
Orthonormality: 1.831467351571864e-15
Topological: -13.417724609375

Losses for topologically optimized pca embedding:
Reconstruction: 7.001532540225682
Orthonormality: 0.004188908729702234
Topological: -50.91328125

Losses for topologically regularized pca embedding:
Reconstruction: 6.990493646422851
Orthonormality: 0.004286207724362612
Topological: -49.7297021484375
Finally, we assess whether the topologically regularized embedding improves on the ordinary PCA embedding for predicting the data point labels.
# Machine learning model to be used for label prediction
Ys = {"pca": Y_pca, "top. opt.": Y_opt, "top. reg.": Y_top}
model = SVC()
scoring = "accuracy"
# Hyperparameters for quantitative evaluation
ntimes = 100
test_frac = 0.1
params = {"C":[0.01, 0.1, 1, 10, 100]}
# Obtain performances over multiple train-test splits
performances = evaluate_embeddings(Ys, t, model, scoring, params=params, stratify=t,
ntimes=ntimes, test_frac=test_frac, random_state=42)
# View resulting performances
pd.concat([pd.DataFrame({"mean":performances.mean(axis=0)}),
pd.DataFrame({"std":performances.std(axis=0)})], axis=1)\
.style.highlight_max(subset="mean", color="lightgreen", axis=0)
|           | mean     | std      |
|-----------|----------|----------|
| pca       | 0.787778 | 0.070402 |
| top. opt. | 0.792963 | 0.072771 |
| top. reg. | 0.807037 | 0.071076 |
Persistent homology can now be used to study the topological information in the embedded data within an exploratory data analysis setting. In this case, it allows us to conveniently obtain and study a representative cycle. We will do this for both our ordinary PCA embedding and our topologically regularized embedding.
Starting with the ordinary PCA embedding, we first obtain a representative cycle from the alpha-filtration as follows.
# Compute persistent homology and obtain a representative cycle
simplices = diode.fill_alpha_shapes(Y_pca)
filtration = dionysus.Filtration(simplices)
PH = dionysus.homology_persistence(filtration)
dgms = dionysus.init_diagrams(PH, filtration)
pt = max(dgms[1], key=lambda pt: pt.death - pt.birth)
# Obtain a representation of the most prominent cycle in the embedding
cycle_raw = PH[PH.pair(pt.data)]
cycle = np.array([list(filtration[s.index]) for s in cycle_raw])
# View the representation of the most prominent cycle in the ordinary PCA embedding
fig, ax = plt.subplots()
for e in cycle:
    plt.plot([Y_pca[e[0], 0], Y_pca[e[1], 0]], [Y_pca[e[0], 1], Y_pca[e[1], 1]],
             linewidth=2, color="black", alpha=0.75, zorder=0)
sns.scatterplot(x=Y_pca[:,0], y=Y_pca[:,1], s=50, hue=t, palette="husl", zorder=1)
plt.show()
We can now project the entire set of embedded data points on the representative cycle as follows.
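Each embedded point is projected onto its closest edge of the representative cycle. For a point $y$ and an edge with endpoints $a$ and $b$, the loop below computes

$$
\mu = \min\!\Bigl(1,\;\max\!\Bigl(0,\;\frac{(y - a)^{\top}(b - a)}{\lVert b - a\rVert^{2}}\Bigr)\Bigr), \qquad p = a + \mu\,(b - a),
$$

and keeps the edge for which $\lVert y - p\rVert$ is smallest, together with its start point, end point, and interpolation parameter $\mu$.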
# Random permutation to better see the overlap in the projection
np.random.seed(69)
random_order = np.random.permutation(Y_pca.shape[0])
# Projection of the ordinary PCA embedding on the representative cycle
P = np.zeros(Y_pca.shape)        # data projection
startpoint = [0 for p in Y_pca]  # edge startpoint for each projected point
endpoint = [0 for p in Y_pca]    # edge endpoint for each projected point
mu = [0 for p in Y_pca]          # normalized distance of projected point to startpoint
for idx1 in range(Y_pca.shape[0]):
    dist = np.inf
    for idx2 in range(cycle.shape[0]):
        this_mu = max(0, min(1, np.sum((Y_pca[idx1,:] - Y_pca[cycle[idx2, 0],:]) *
                                       (Y_pca[cycle[idx2, 1],:] - Y_pca[cycle[idx2, 0],:])) /
                                np.sum((Y_pca[cycle[idx2, 1],:] - Y_pca[cycle[idx2, 0],:])**2)))
        projection = Y_pca[cycle[idx2, 0],:] + this_mu * (Y_pca[cycle[idx2, 1],:] - Y_pca[cycle[idx2, 0],:])
        this_dist = np.linalg.norm(Y_pca[idx1,:] - projection)
        if this_dist < dist:
            P[idx1,:] = projection
            startpoint[idx1] = cycle[idx2, 0]
            endpoint[idx1] = cycle[idx2, 1]
            mu[idx1] = this_mu
            dist = this_dist
# View the projection of the embedded data points on the representative cycle
fig, ax = plt.subplots()
for e in cycle:
    plt.plot([Y_pca[e[0], 0], Y_pca[e[1], 0]], [Y_pca[e[0], 1], Y_pca[e[1], 1]],
             linewidth=2, color="black", alpha=0.25)
sns.scatterplot(x=P[random_order, 0], y=P[random_order, 1], s=50,
                hue=[t[idx] for idx in random_order], palette="husl", hue_order=["G1", "G2M", "S"])
ax.get_legend().remove()
plt.show()
Finally, we use this projection to obtain circular coordinates for the entire data.
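Concretely, cumulative arc lengths $C$ are first computed along the cycle (after cutting it open at one edge); a data point projecting onto the edge from node $s$ to node $e$ with parameter $\mu$ then receives the circular coordinate

$$
\theta = \frac{2\pi}{L}\,\bigl((1 - \mu)\,C_{s} + \mu\,C_{e}\bigr),
$$

where $L$ is the total length of the cycle, i.e., the summed edge lengths along the opened path plus the length of the edge that closes it.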
# First obtain coordinates for the points on the representative cycle
G = nx.Graph()
G.add_edges_from(cycle)
cycle_nodes = np.unique(cycle)
start_node = cycle_nodes[np.argmax(Y_pca[cycle_nodes,0])]
end_node = list(G.neighbors(start_node))[np.argmax(Y_pca[list(G.neighbors(start_node)),1])]
G.remove_edge(start_node, end_node)
path = nx.shortest_path(G, start_node, end_node)
C = np.zeros([Y_pca.shape[0]])
for idx in range(1, len(path)):
    C[path[idx]] = np.linalg.norm(Y_pca[path[idx],:] - Y_pca[path[idx - 1],:]) + C[path[idx - 1]]
# Obtain a circular coordinate for each data point
for idx in range(Y_pca.shape[0]):
    C[idx] = (1 - mu[idx]) * C[startpoint[idx]] + mu[idx] * C[endpoint[idx]]
C = C.astype("float") * 2 * np.pi / (max(C) + np.linalg.norm(Y_pca[path[len(path) - 1],:] - Y_pca[path[0],:]))
# View the points in the ordinary PCA embedding using hue to show the circular coordinates
fig, ax = plt.subplots()
sns.scatterplot(x=Y_pca[:,0], y=Y_pca[:,1], s=50, c=C, cmap="hsv")
plt.show()
We now repeat the analysis for the topologically regularized embedding, again obtaining a representative cycle from the alpha-filtration as follows.
# Compute persistent homology and obtain a representative cycle
simplices = diode.fill_alpha_shapes(Y_top)
filtration = dionysus.Filtration(simplices)
PH = dionysus.homology_persistence(filtration)
dgms = dionysus.init_diagrams(PH, filtration)
pt = max(dgms[1], key=lambda pt: pt.death - pt.birth)
# Obtain a representation of the most prominent cycle in the embedding
cycle_raw = PH[PH.pair(pt.data)]
cycle = np.array([list(filtration[s.index]) for s in cycle_raw])
# View the representation of the most prominent cycle in the topologically regularized embedding
fig, ax = plt.subplots()
for e in cycle:
    plt.plot([Y_top[e[0], 0], Y_top[e[1], 0]], [Y_top[e[0], 1], Y_top[e[1], 1]],
             linewidth=2, color="black", alpha=0.75)
sns.scatterplot(x=Y_top[:,0], y=Y_top[:,1], s=50, hue=t, palette="husl")
plt.show()
We can now project the entire set of embedded data points on the representative cycle as follows.
# Random permutation to better see the overlap in the projection
np.random.seed(42)
random_order = np.random.permutation(Y_top.shape[0])
# Projection of topologically regularized data embedding on representative cycle
P = np.zeros(Y_top.shape) # data projection
startpoint = [0 for p in Y_top] # edge startpoint for each projected point
endpoint = [0 for p in Y_top] # edge endpoint for each projected point
mu = [0 for p in Y_top] # normalized distance of projected point to startpoint
for idx1 in range(Y_top.shape[0]):
    dist = np.inf
    for idx2 in range(cycle.shape[0]):
        this_mu = max(0, min(1, np.sum((Y_top[idx1,:] - Y_top[cycle[idx2, 0],:]) *
                                       (Y_top[cycle[idx2, 1],:] - Y_top[cycle[idx2, 0],:])) /
                                np.sum((Y_top[cycle[idx2, 1],:] - Y_top[cycle[idx2, 0],:])**2)))
        projection = Y_top[cycle[idx2, 0],:] + this_mu * (Y_top[cycle[idx2, 1],:] - Y_top[cycle[idx2, 0],:])
        this_dist = np.linalg.norm(Y_top[idx1,:] - projection)
        if this_dist < dist:
            P[idx1,:] = projection
            startpoint[idx1] = cycle[idx2, 0]
            endpoint[idx1] = cycle[idx2, 1]
            mu[idx1] = this_mu
            dist = this_dist
# View the projection of the embedded data points on the representative cycle
fig, ax = plt.subplots()
for e in cycle:
    plt.plot([Y_top[e[0], 0], Y_top[e[1], 0]], [Y_top[e[0], 1], Y_top[e[1], 1]],
             linewidth=2, color="black", alpha=0.25)
sns.scatterplot(x=P[random_order, 0], y=P[random_order, 1], s=50,
                hue=[t[idx] for idx in random_order], palette="husl", hue_order=["G1", "G2M", "S"])
ax.get_legend().remove()
plt.show()
We now use this projection to obtain circular coordinates for the entire data.
# First obtain coordinates for the points on the representative cycle
G = nx.Graph()
G.add_edges_from(cycle)
cycle_nodes = np.unique(cycle)
start_node = cycle_nodes[np.argmax(Y_top[cycle_nodes,0])]
end_node = list(G.neighbors(start_node))[np.argmax(Y_top[list(G.neighbors(start_node)),1])]
G.remove_edge(start_node, end_node)
path = nx.shortest_path(G, start_node, end_node)
C = np.zeros([Y_top.shape[0]])
for idx in range(1, len(path)):
    C[path[idx]] = np.linalg.norm(Y_top[path[idx],:] - Y_top[path[idx - 1],:]) + C[path[idx - 1]]
# Obtain a circular coordinate for each data point
for idx in range(Y_top.shape[0]):
    C[idx] = (1 - mu[idx]) * C[startpoint[idx]] + mu[idx] * C[endpoint[idx]]
C = C.astype("float") * 2 * np.pi / (max(C) + np.linalg.norm(Y_top[path[len(path) - 1],:] - Y_top[path[0],:]))
# View the points in the topologically regularized embedding using hue to show the circular coordinates
fig, ax = plt.subplots()
sns.scatterplot(x=Y_top[:,0], y=Y_top[:,1], s=50, c=C, cmap="hsv")
plt.show()
Finally, we explore how topological regularization reacts to different sampling fractions $f_{\mathcal{S}}$ and repeats $n_{\mathcal{S}}$. The different embeddings are obtained as follows.
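For a sampling fraction $f_{\mathcal{S}}$ and $n_{\mathcal{S}}$ repeats, the loss used below averages the persistence of the most prominent cycle over $n_{\mathcal{S}}$ random subsamples of $\lfloor f_{\mathcal{S}} \cdot N \rfloor$ embedded points:

$$
\mathcal{L}_{\mathrm{top}}(Y) = \frac{\lambda_{\mathrm{top}}}{n_{\mathcal{S}}} \sum_{k=1}^{n_{\mathcal{S}}} \bigl(d_1^{(k)} - b_1^{(k)}\bigr),
$$

with $(b_1^{(k)}, d_1^{(k)})$ the most persistent one-dimensional pair of the $k$-th subsample. Since all subsamples coincide when $f_{\mathcal{S}} = 1$, that fraction is only run with $n_{\mathcal{S}} = 1$ and the result is reused for the other repeat settings.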
# Define sampling fractions and repeats
top_fracs = [0.25, 0.5, 0.75, 1]
repeats = [1, 10, 25]
# Construct embedding for each sampling fraction and repeat combination
Y_samplings = {}
for f in top_fracs:
    Y_samplings[f] = {}
    for n in repeats:
        Y_samplings[f][n] = {}
        if f < 1 or n == 1:
            # Define the topological loss function for this sampling fraction and repeat combination
            def this_sampling_loss(output):
                loss = torch.tensor(0).type(torch.float)
                for _ in range(n):
                    sample = random.sample(range(output.shape[0]), int(output.shape[0] * f))
                    dgminfo = TopLayer(output[sample,:])
                    loss += CircularPersistence(dgminfo)
                loss = lambda_top * loss / n
                return loss
            # Track total embedding time
            start_time = time.time()
            # Conduct embedding
            print("\033[1mConducting embedding for sampling fraction " + str(f) +
                  " with " + str(n) + " repeats\033[0m")
            Y_samplings[f][n]["emb"] = PCA(data, top_loss=this_sampling_loss, num_epochs=num_epochs,
                                           learning_rate=learning_rate, random_state=42)[0]
            print("\n")
            # Store total embedding time
            Y_samplings[f][n]["time"] = time.time() - start_time
        else:
            Y_samplings[f][n]["emb"] = Y_samplings[f][1]["emb"]
            Y_samplings[f][n]["time"] = Y_samplings[f][1]["time"]
Conducting embedding for sampling fraction 0.25 with 1 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1029.116089, total loss: -1022.410156] [epoch 100] [emb. loss: 6.794199, ortho. loss: 39.050903, top. loss: -2206.719482, total loss: -2160.874268] [epoch 200] [emb. loss: 6.871301, ortho. loss: 28.582264, top. loss: -3853.351318, total loss: -3817.897705] [epoch 300] [emb. loss: 6.911356, ortho. loss: 15.225312, top. loss: -4771.732422, total loss: -4749.595703] [epoch 400] [emb. loss: 6.944156, ortho. loss: 78.957901, top. loss: -5473.927246, total loss: -5388.025391] [epoch 500] [emb. loss: 6.963151, ortho. loss: 68.097542, top. loss: -3985.763672, total loss: -3910.702881] [epoch 600] [emb. loss: 6.969546, ortho. loss: 54.162918, top. loss: -5400.137207, total loss: -5339.004883] [epoch 700] [emb. loss: 6.980327, ortho. loss: 28.453640, top. loss: -5472.357422, total loss: -5436.923340] [epoch 800] [emb. loss: 6.993236, ortho. loss: 55.716778, top. loss: -5505.524414, total loss: -5442.814453] [epoch 900] [emb. loss: 6.992912, ortho. loss: 29.378899, top. loss: -5253.819824, total loss: -5217.448242] [epoch 1000] [emb. loss: 6.990539, ortho. loss: 6.560929, top. loss: -4963.142578, total loss: -4949.591309] Time for embedding: 00:00:32 Conducting embedding for sampling fraction 0.25 with 10 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1175.824707, total loss: -1169.118774] [epoch 100] [emb. loss: 6.790902, ortho. loss: 72.855339, top. loss: -3278.387939, total loss: -3198.741699] [epoch 200] [emb. loss: 6.877039, ortho. loss: 36.844570, top. loss: -4887.672852, total loss: -4843.951172] [epoch 300] [emb. loss: 6.947292, ortho. loss: 74.844597, top. loss: -5128.823242, total loss: -5047.031250] [epoch 400] [emb. loss: 6.978375, ortho. loss: 10.693149, top. loss: -5341.971680, total loss: -5324.300293] [epoch 500] [emb. loss: 6.996393, ortho. loss: 100.883926, top. loss: -5417.264648, total loss: -5309.384277] [epoch 600] [emb. loss: 7.003494, ortho. loss: 90.685127, top. loss: -5375.957031, total loss: -5278.268555] [epoch 700] [emb. loss: 7.006982, ortho. loss: 77.177429, top. loss: -5275.533691, total loss: -5191.349121] [epoch 800] [emb. loss: 7.011427, ortho. loss: 33.695168, top. loss: -5543.634766, total loss: -5502.928223] [epoch 900] [emb. loss: 7.012042, ortho. loss: 48.824753, top. loss: -5524.329590, total loss: -5468.492676] [epoch 1000] [emb. loss: 7.018599, ortho. loss: 42.345371, top. loss: -5718.063477, total loss: -5668.699707] Time for embedding: 00:02:32 Conducting embedding for sampling fraction 0.25 with 25 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1260.585693, total loss: -1253.879761] [epoch 100] [emb. loss: 6.817028, ortho. loss: 99.548416, top. loss: -4788.339844, total loss: -4681.974609] [epoch 200] [emb. loss: 6.892045, ortho. loss: 33.699192, top. loss: -5324.106445, total loss: -5283.515137] [epoch 300] [emb. loss: 6.926610, ortho. loss: 37.064960, top. loss: -5962.265625, total loss: -5918.273926] [epoch 400] [emb. loss: 6.946918, ortho. loss: 23.145182, top. loss: -5930.549805, total loss: -5900.457520] [epoch 500] [emb. loss: 6.961986, ortho. loss: 52.085140, top. loss: -5497.619141, total loss: -5438.571777] [epoch 600] [emb. loss: 6.972620, ortho. loss: 31.867531, top. loss: -5670.656250, total loss: -5631.815918] [epoch 700] [emb. loss: 6.976828, ortho. loss: 100.536942, top. loss: -5838.679199, total loss: -5731.165527] [epoch 800] [emb. loss: 6.980758, ortho. 
loss: 13.700308, top. loss: -5455.537598, total loss: -5434.856445] [epoch 900] [emb. loss: 6.981705, ortho. loss: 45.371075, top. loss: -5625.951660, total loss: -5573.599121] [epoch 1000] [emb. loss: 6.982512, ortho. loss: 32.492855, top. loss: -5843.659180, total loss: -5804.183594] Time for embedding: 00:05:51 Conducting embedding for sampling fraction 0.5 with 1 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1005.261963, total loss: -998.556030] [epoch 100] [emb. loss: 6.804258, ortho. loss: 52.676682, top. loss: -4761.550293, total loss: -4702.069336] [epoch 200] [emb. loss: 6.856610, ortho. loss: 55.447510, top. loss: -6269.050781, total loss: -6206.746582] [epoch 300] [emb. loss: 6.880493, ortho. loss: 48.983494, top. loss: -6358.332520, total loss: -6302.468750] [epoch 400] [emb. loss: 6.898883, ortho. loss: 47.144676, top. loss: -4951.137695, total loss: -4897.094238] [epoch 500] [emb. loss: 6.912720, ortho. loss: 68.769531, top. loss: -7004.153320, total loss: -6928.471191] [epoch 600] [emb. loss: 6.944167, ortho. loss: 57.718338, top. loss: -6296.945312, total loss: -6232.282715] [epoch 700] [emb. loss: 6.938440, ortho. loss: 49.248798, top. loss: -6499.519531, total loss: -6443.332520] [epoch 800] [emb. loss: 6.946087, ortho. loss: 41.352200, top. loss: -6552.672363, total loss: -6504.374023] [epoch 900] [emb. loss: 6.963080, ortho. loss: 64.740143, top. loss: -6733.777832, total loss: -6662.074707] [epoch 1000] [emb. loss: 6.964814, ortho. loss: 55.813107, top. loss: -6791.526855, total loss: -6728.749023] Time for embedding: 00:00:45 Conducting embedding for sampling fraction 0.5 with 10 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1117.711182, total loss: -1111.005249] [epoch 100] [emb. loss: 6.822650, ortho. loss: 141.338440, top. loss: -6148.577637, total loss: -6000.416504] [epoch 200] [emb. loss: 6.879549, ortho. loss: 43.860413, top. loss: -6809.450684, total loss: -6758.710938] [epoch 300] [emb. loss: 6.909528, ortho. loss: 59.103874, top. loss: -6670.020996, total loss: -6604.007812] [epoch 400] [emb. loss: 6.928156, ortho. loss: 86.793015, top. loss: -6959.907715, total loss: -6866.186523] [epoch 500] [emb. loss: 6.936523, ortho. loss: 112.138298, top. loss: -6809.162598, total loss: -6690.087891] [epoch 600] [emb. loss: 6.938022, ortho. loss: 65.750793, top. loss: -6899.799316, total loss: -6827.110352] [epoch 700] [emb. loss: 6.946169, ortho. loss: 29.287359, top. loss: -6629.323242, total loss: -6593.089844] [epoch 800] [emb. loss: 6.945699, ortho. loss: 214.211227, top. loss: -7263.698242, total loss: -7042.541504] [epoch 900] [emb. loss: 6.947989, ortho. loss: 77.101959, top. loss: -6996.594727, total loss: -6912.544922] [epoch 1000] [emb. loss: 6.949709, ortho. loss: 90.907501, top. loss: -6781.801758, total loss: -6683.944336] Time for embedding: 00:05:04 Conducting embedding for sampling fraction 0.5 with 25 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1232.841064, total loss: -1226.135132] [epoch 100] [emb. loss: 6.822948, ortho. loss: 71.643867, top. loss: -6513.216309, total loss: -6434.749512] [epoch 200] [emb. loss: 6.875990, ortho. loss: 73.378647, top. loss: -6883.404297, total loss: -6803.149414] [epoch 300] [emb. loss: 6.903165, ortho. loss: 38.359657, top. loss: -6834.091309, total loss: -6788.828613] [epoch 400] [emb. loss: 6.917817, ortho. loss: 55.901146, top. loss: -7091.703125, total loss: -7028.884277] [epoch 500] [emb. loss: 6.923595, ortho. 
loss: 25.987488, top. loss: -6886.497559, total loss: -6853.586426] [epoch 600] [emb. loss: 6.932153, ortho. loss: 42.604115, top. loss: -7089.366699, total loss: -7039.830566] [epoch 700] [emb. loss: 6.936269, ortho. loss: 37.121876, top. loss: -7002.799805, total loss: -6958.741699] [epoch 800] [emb. loss: 6.942491, ortho. loss: 50.332947, top. loss: -7004.552734, total loss: -6947.277344] [epoch 900] [emb. loss: 6.942492, ortho. loss: 77.929703, top. loss: -7178.382324, total loss: -7093.510254] [epoch 1000] [emb. loss: 6.943398, ortho. loss: 42.242138, top. loss: -7106.476074, total loss: -7057.290527] Time for embedding: 00:11:26 Conducting embedding for sampling fraction 0.75 with 1 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1017.175110, total loss: -1010.469177] [epoch 100] [emb. loss: 6.795101, ortho. loss: 48.520687, top. loss: -6431.034180, total loss: -6375.718262] [epoch 200] [emb. loss: 6.848928, ortho. loss: 49.549816, top. loss: -6953.195312, total loss: -6896.796387] [epoch 300] [emb. loss: 6.893303, ortho. loss: 87.150078, top. loss: -6987.529785, total loss: -6893.486328] [epoch 400] [emb. loss: 6.910672, ortho. loss: 35.913963, top. loss: -7024.847656, total loss: -6982.022949] [epoch 500] [emb. loss: 6.913319, ortho. loss: 37.680935, top. loss: -7229.406738, total loss: -7184.812500] [epoch 600] [emb. loss: 6.926466, ortho. loss: 77.884842, top. loss: -7497.289062, total loss: -7412.477539] [epoch 700] [emb. loss: 6.930403, ortho. loss: 73.753166, top. loss: -7331.033203, total loss: -7250.349609] [epoch 800] [emb. loss: 6.944964, ortho. loss: 25.071203, top. loss: -6928.309082, total loss: -6896.292969] [epoch 900] [emb. loss: 6.948291, ortho. loss: 42.189098, top. loss: -7219.701172, total loss: -7170.563965] [epoch 1000] [emb. loss: 6.944431, ortho. loss: 58.217113, top. loss: -7238.951660, total loss: -7173.790039] Time for embedding: 00:01:00 Conducting embedding for sampling fraction 0.75 with 10 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1143.817383, total loss: -1137.111450] [epoch 100] [emb. loss: 6.806490, ortho. loss: 34.259598, top. loss: -6361.401855, total loss: -6320.335938] [epoch 200] [emb. loss: 6.864360, ortho. loss: 48.831562, top. loss: -6863.896973, total loss: -6808.201172] [epoch 300] [emb. loss: 6.894897, ortho. loss: 65.669640, top. loss: -7279.370117, total loss: -7206.805664] [epoch 400] [emb. loss: 6.912440, ortho. loss: 46.027889, top. loss: -7415.944336, total loss: -7363.003906] [epoch 500] [emb. loss: 6.918342, ortho. loss: 6.206989, top. loss: -7194.991211, total loss: -7181.865723] [epoch 600] [emb. loss: 6.926075, ortho. loss: 15.658889, top. loss: -7319.033691, total loss: -7296.448730] [epoch 700] [emb. loss: 6.929006, ortho. loss: 28.909306, top. loss: -7614.259277, total loss: -7578.420898] [epoch 800] [emb. loss: 6.930554, ortho. loss: 50.749401, top. loss: -7343.166992, total loss: -7285.486816] [epoch 900] [emb. loss: 6.930087, ortho. loss: 75.291969, top. loss: -7146.584473, total loss: -7064.362305] [epoch 1000] [emb. loss: 6.927831, ortho. loss: 66.628464, top. loss: -7399.719727, total loss: -7326.163574] Time for embedding: 00:07:19 Conducting embedding for sampling fraction 0.75 with 25 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1145.170654, total loss: -1138.464722] [epoch 100] [emb. loss: 6.807723, ortho. loss: 28.695347, top. loss: -6806.242676, total loss: -6770.739746] [epoch 200] [emb. loss: 6.864093, ortho. 
loss: 96.747849, top. loss: -7051.673340, total loss: -6948.061523] [epoch 300] [emb. loss: 6.891009, ortho. loss: 25.569204, top. loss: -7224.132324, total loss: -7191.671875] [epoch 400] [emb. loss: 6.906007, ortho. loss: 36.772480, top. loss: -7478.990723, total loss: -7435.312012] [epoch 500] [emb. loss: 6.915853, ortho. loss: 23.008736, top. loss: -7432.014160, total loss: -7402.089355] [epoch 600] [emb. loss: 6.920774, ortho. loss: 57.299809, top. loss: -7505.043945, total loss: -7440.823242] [epoch 700] [emb. loss: 6.921452, ortho. loss: 40.476604, top. loss: -7596.955566, total loss: -7549.557617] [epoch 800] [emb. loss: 6.924866, ortho. loss: 65.514778, top. loss: -7640.169922, total loss: -7567.730469] [epoch 900] [emb. loss: 6.927379, ortho. loss: 16.394178, top. loss: -7546.125000, total loss: -7522.803223] [epoch 1000] [emb. loss: 6.924320, ortho. loss: 34.368752, top. loss: -7697.785156, total loss: -7656.492188] Time for embedding: 00:17:16 Conducting embedding for sampling fraction 1 with 1 repeats [epoch 1] [emb. loss: 6.704727, ortho. loss: 0.001194, top. loss: -1017.716614, total loss: -1011.010681] [epoch 100] [emb. loss: 6.778495, ortho. loss: 31.168331, top. loss: -6753.857422, total loss: -6715.910645] [epoch 200] [emb. loss: 6.835660, ortho. loss: 14.663120, top. loss: -6998.564941, total loss: -6977.066406] [epoch 300] [emb. loss: 6.870476, ortho. loss: 50.959778, top. loss: -7974.586426, total loss: -7916.756348] [epoch 400] [emb. loss: 6.888778, ortho. loss: 99.523521, top. loss: -7788.938965, total loss: -7682.526855] [epoch 500] [emb. loss: 6.900532, ortho. loss: 39.350815, top. loss: -8044.899902, total loss: -7998.648438] [epoch 600] [emb. loss: 6.909189, ortho. loss: 45.929207, top. loss: -7810.061523, total loss: -7757.223145] [epoch 700] [emb. loss: 6.913458, ortho. loss: 23.038944, top. loss: -8035.033203, total loss: -8005.080566] [epoch 800] [emb. loss: 6.919832, ortho. loss: 10.664638, top. loss: -7364.859863, total loss: -7347.275391] [epoch 900] [emb. loss: 6.921590, ortho. loss: 24.885702, top. loss: -8023.665527, total loss: -7991.858398] [epoch 1000] [emb. loss: 6.925286, ortho. loss: 45.186710, top. loss: -7881.346191, total loss: -7829.234375] Time for embedding: 00:01:12
We visualize all embeddings as follows.
# Visualize all embeddings
fig, ax = plt.subplots(len(top_fracs), len(repeats), figsize=(len(repeats) * 5, len(top_fracs) * 3.75),
                       gridspec_kw={"hspace": 0.3})
for idx1, f in enumerate(top_fracs):
    for idx2, n in enumerate(repeats):
        sns.scatterplot(x=Y_samplings[f][n]["emb"][:,0], y=Y_samplings[f][n]["emb"][:,1],
                        s=50, hue=t, palette="husl", ax=ax[idx1][idx2])
        ax[idx1][idx2].get_legend().remove()
        ax[idx1][idx2].set_title("$f_S$ = {}, $n_S$ = {}".format(f, n))
plt.show()
We see that, overall, the embedding does not vary significantly with these hyperparameters. The computation time, however, can vary considerably, as illustrated below.
# Construct data set for visualization
pd_time = pd.DataFrame()
for f in top_fracs:
    for n in repeats:
        pd_time = pd.concat([pd_time, pd.DataFrame({"f_S":[f], "n_S":[n], "time":[Y_samplings[f][n]["time"]]})])
# Visualize computation times by sampling fraction and repeat
fig, ax = plt.subplots(1, 2, figsize=(15, 3.5))
sns.lineplot(data=pd_time.iloc[np.where(np.logical_or(pd_time["f_S"] != 1, pd_time["n_S"] == 1))[0],:],
x="f_S", y="time", hue="n_S", ax=ax[0], palette="Set2")
ax[0].set_title("Embedding time by sampling fraction $f_S$")
ax[0].set_xlabel("$f_S$")
ax[0].set_ylabel("time (s)")
sns.lineplot(data=pd_time, x="n_S", y="time", hue="f_S", ax=ax[1], palette="Set2")
ax[1].set_title("Embedding time by number of repeats $n_S$")
ax[1].set_xlabel("$n_S$")
ax[1].set_ylabel("time (s)")
plt.show()