#' Evaluate Q(z) = 2 - sqrt(4 + z^2) + z * asinh(z / 2).
#'
#' The value is computed in an algebraically equivalent form that avoids the
#' catastrophic cancellation of `2 - sqrt(4 + z^2)` when `|z|` is small
#' (where Q(z) behaves like z^2 / 4).
#'
#' @param z Numeric vector.
#' @return Numeric vector of the same length as `z`.
Q <- function(z) {
  # (2 - s) * (2 + s) = 4 - s^2 = -z^2 for s = sqrt(4 + z^2), hence
  # 2 - s = -z^2 / (2 + s); the denominator is always >= 4.
  z * asinh(z / 2) - z^2 / (2 + sqrt(4 + z^2))
}
#' Scaled Q penalty.
#'
#' Forms the scale `abs(beta_init) + gamma^2`, evaluates `Q()` at
#' `2 * beta / scale`, and multiplies the result by the same scale.
#'
#' @param beta Numeric vector at which the penalty is evaluated.
#' @param beta_init Numeric vector; only its absolute value is used.
#' @param gamma Numeric; enters the scale as `gamma^2`.
#' @return Numeric vector `scale * Q(2 * beta / scale)`.
q_norm <- function(beta, beta_init, gamma) {
  scale <- abs(beta_init) + gamma^2
  scale * Q(2 * beta / scale)
}
#' Evaluate sqrt(beta^2 + beta_init^2) - abs(beta_init).
#'
#' Where possible the value is computed through the conjugate form
#' `beta^2 / (sqrt(beta^2 + beta_init^2) + abs(beta_init))`, which does not
#' lose precision when `|beta|` is much smaller than `|beta_init|`. Elements
#' whose denominator is zero, infinite or missing fall back to the direct
#' formula, so those cases return exactly what the direct formula returns.
#'
#' @param beta Numeric vector.
#' @param beta_init Numeric vector (recycled against `beta` by the usual rules).
#' @return Numeric vector of penalty values.
mt_norm <- function(beta, beta_init) {
  target <- abs(beta_init)
  hyp <- sqrt(beta^2 + beta_init^2)

  # Direct formula; kept as the result wherever the stable form is unusable.
  out <- hyp - target

  denom <- hyp + target
  ratio <- beta^2 / denom
  # is.finite() is FALSE for NA/NaN/Inf, so `stable` never contains NA.
  stable <- is.finite(denom) & denom > 0
  out[stable] <- ratio[stable]
  out
}
# Panel (a): `mt_norm` against `beta` on log-log axes, one line per value of
# `beta_init` (shown in the legend as the auxiliary magnitude).
fig_a <-
  expand.grid(
    beta_init = c(1e-3, 1, 1e3),
    beta = 10^seq(-2, 2, by = 0.01)
  ) %>%
  as_tibble() %>%
  # Only `mt_norm` is drawn in this panel, so it is the only quantity computed.
  mutate(value = mt_norm(beta, beta_init)) %>%
  ggplot(aes(beta, value, color = factor(beta_init), group = beta_init)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  # NOTE(review): `size` for line width is deprecated from ggplot2 3.4.0 in
  # favour of `linewidth`; kept as-is for compatibility with older versions.
  geom_line(size = 0.25) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  my_theme +
  # The legend title is a plain string: wrapping it in TeX() turns it into a
  # plotmath expression, which cannot render the line break and triggers
  # "font metrics unknown for character 0xa" warnings.
  labs(x = TeX('Magnitude'), y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_a

# Panel (b): `q_norm` (with gamma = 1e-3) against `beta` on log-log axes, one
# line per value of `beta_init` (shown in the legend as the auxiliary magnitude).
fig_b <-
  expand.grid(
    beta_init = c(1e-3, 1, 1e3),
    beta = 10^seq(-2, 2, by = 0.01)
  ) %>%
  as_tibble() %>%
  # Only `q_norm` is drawn in this panel, so it is the only quantity computed.
  mutate(value = q_norm(beta, beta_init, 1e-3)) %>%
  ggplot(aes(beta, value, color = factor(beta_init), group = beta_init)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  # NOTE(review): `size` for line width is deprecated from ggplot2 3.4.0 in
  # favour of `linewidth`; kept as-is for compatibility with older versions.
  geom_line(size = 0.25) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  my_theme +
  # The legend title is a plain string: wrapping it in TeX() turns it into a
  # plotmath expression, which cannot render the line break and triggers
  # "font metrics unknown for character 0xa" warnings.
  labs(x = TeX('Magnitude'), y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_b

import numpy as np
# Sweep settings: every (theta, beta_aux, beta) combination is evaluated.
alpha = 0
beta_vals = np.logspace(-2, 2, 100, base=10)
theta_vals = [1., 0.99, 0.9, 0.0]
beta_aux_vals = [0.001, 1, 1000]

# Long-format results, one entry per combination; read from R as py$df.
df = {'norm': [], 'beta': [], 'beta_aux': [], 'theta': []}
# The same norm values indexed as [theta, beta_aux, beta]. The array was
# previously allocated but never written; it is now filled alongside `df`.
norms = np.zeros([len(theta_vals), len(beta_aux_vals), len(beta_vals)])


def _positive_real_root(coeffs, imag_tol=1e-6):
    """Return the last positive, numerically real root of ``coeffs``, or None.

    A root counts as real when the magnitude of its imaginary part is below
    ``imag_tol``. If several roots qualify, the last one in ``np.roots`` order
    is returned.
    """
    found = None
    for root in np.roots(coeffs):
        if np.abs(np.imag(root)) < imag_tol and np.real(root) > 0:
            found = np.real(root)
    return found


v_0 = alpha
for a, theta in enumerate(theta_vals):
    for b, beta_aux in enumerate(beta_aux_vals):
        # Depends only on beta_aux, so it is computed once per beta_aux.
        m_0 = np.sqrt(beta_aux)
        for c, beta in enumerate(beta_vals):
            # Quartic in m: m^4 - m_0*theta*m^3 + v_0*beta*m - beta^2 = 0.
            m = _positive_real_root([1, -m_0*theta, 0, v_0*beta, -beta**2])
            if m is None:
                # Fail with the offending inputs instead of a TypeError on beta/m.
                raise ValueError(
                    'no positive real root for theta=%r, beta_aux=%r, beta=%r'
                    % (theta, beta_aux, beta)
                )
            norm = (beta/m - v_0)**2 + m**2 + m_0**2 - 2*m*m_0*theta
            norms[a, b, c] = norm
            df['norm'].append(norm)
            df['beta'].append(beta)
            df['beta_aux'].append(beta_aux)
            df['theta'].append(theta)
# Pull the four result vectors out of the Python-side `df` dict, coerce each
# to double, and assemble them into a tibble with the same column names.
df_ft_relu <-
  py$df[c('norm', 'beta', 'beta_aux', 'theta')] %>%
  lapply(as.double) %>%
  as_tibble()
# Panel (c): `norm` against `beta` on log-log axes, one line per `beta_aux`,
# with one facet per `theta` value (facet labels read "Corr.: <theta>").
fig_c <-
  df_ft_relu %>%
  mutate(
    # Explicit factor levels fix the facet order to 1, 0.99, 0.9, 0.
    theta = paste('Corr.:', theta) %>%
      factor(levels = paste('Corr.:', c(1, 0.99, 0.9, 0)))
  ) %>%
  ggplot(aes(beta, norm, color = factor(beta_aux), group = beta_aux)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  # NOTE(review): `size` for line width is deprecated from ggplot2 3.4.0 in
  # favour of `linewidth`; kept as-is for compatibility with older versions.
  geom_line(size = 0.25) +
  facet_wrap(~theta, nrow = 1) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  coord_cartesian(ylim = c(1e-6, NA)) +
  my_theme +
  # The legend title is a plain string: wrapping it in TeX() turns it into a
  # plotmath expression, which cannot render the line break and triggers
  # "font metrics unknown for character 0xa" warnings.
  labs(x = TeX('Magnitude'), y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_c

# Combined figure: the three panels side by side in a single row with relative
# widths 1:1:4. The colour legend is dropped from panels (b) and (c), leaving
# the one from panel (a); panels are tagged with lowercase letters.
fig <-
  (
    fig_a +
      (fig_b + guides(color = guide_none())) +
      (fig_c + guides(color = guide_none()))
  ) +
  plot_layout(nrow = 1, widths = c(1, 1, 4)) +
  plot_annotation(tag_levels = 'a')
fig

# Write the combined figure to PDF. The plot is passed explicitly instead of
# relying on ggsave()'s default of last_plot(), so the saved figure does not
# depend on which plot happened to be created or displayed most recently.
ggsave(
  'figures/neurips-penalties.pdf',
  plot = fig,
  width = width,
  height = 0.25 * width,
  units = 'cm'
)
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font metrics unknown for character 0xa
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font metrics unknown for character 0xa