# Scalar potential underlying `q_norm()`:
#   Q(z) = 2 - sqrt(4 + z^2) + z * asinh(z / 2)
# Vectorised over `z`. Q is even in `z`, Q(0) = 0, and Q(z) ~ z^2 / 4 near 0.
#
# The term 2 - sqrt(4 + z^2) is evaluated as -z^2 / (2 + sqrt(4 + z^2))
# (algebraically equal). The direct difference cancels catastrophically for
# small |z|. For example, at z = 1e-8 it evaluates to 0, which makes Q off by a
# factor of 2.
Q <- function(z) {
  z * asinh(z / 2) - z^2 / (2 + sqrt(4 + z^2))
}
# Rescaled version of `Q()`. With s = |beta_init| + gamma^2:
#   q_norm(beta, beta_init, gamma) = s * Q(2 * beta / s)
# Vectorised over all three arguments (usual R recycling).
q_norm <- function(beta, beta_init, gamma) {
  s <- abs(beta_init) + gamma^2
  s * Q(2 * beta / s)
}
# Norm with auxiliary magnitude `beta_init`:
#   mt_norm(beta, beta_init) = sqrt(beta^2 + beta_init^2) - |beta_init|
# Vectorised over both arguments (usual R recycling). The value is
# non-negative, even in `beta`, and 0 at beta = 0.
#
# Evaluated via the algebraically equal form
#   beta^2 / (sqrt(beta^2 + beta_init^2) + |beta_init|),
# which avoids the cancellation of the direct difference when
# |beta| << |beta_init|. For example, at beta = 1e-5 and beta_init = 1e3 the
# direct form loses most or all significant digits.
# Where that denominator is 0 (both inputs zero) or not finite, the direct
# difference is used, so those edge cases return what they did before.
mt_norm <- function(beta, beta_init) {
  radius <- sqrt(beta^2 + beta_init^2)
  denom <- radius + abs(beta_init)
  ifelse(
    denom > 0 & is.finite(denom),
    beta^2 / denom,
    radius - abs(beta_init)
  )
}
# Panel a: `mt_norm()` against the magnitude, one line per auxiliary
# magnitude (1e-3, 1, 1e3), on log-log axes. The dashed line marks
# magnitude 1.
fig_a <-
  expand.grid(
    beta_init = c(1e-3, 1, 1e3),
    beta = 10^seq(-2, 2, 0.01)
  ) %>%
  as_tibble() %>%
  # Only this panel's norm is plotted, so compute it directly instead of
  # computing both norms and discarding one after pivot_longer().
  mutate(value = mt_norm(beta, beta_init)) %>%
  ggplot(aes(beta, value, color = factor(beta_init), group = beta_init)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  geom_line(size = 0.25) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  my_theme +
  # Plain strings rather than TeX(): TeX() turns the text into a plotmath
  # expression, and plotmath does not interpret "\n". It emits
  # "font metrics unknown for character 0xa" warnings instead of a line break.
  labs(x = 'Magnitude', y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_a

# Panel b: `q_norm()` (gamma = 1e-3) against the magnitude, one line per
# auxiliary magnitude (1e-3, 1, 1e3), on log-log axes. The dashed line
# marks magnitude 1.
fig_b <-
  expand.grid(
    beta_init = c(1e-3, 1, 1e3),
    beta = 10^seq(-2, 2, 0.01)
  ) %>%
  as_tibble() %>%
  # Only this panel's norm is plotted, so compute it directly instead of
  # computing both norms and discarding one after pivot_longer().
  mutate(value = q_norm(beta, beta_init, 1e-3)) %>%
  ggplot(aes(beta, value, color = factor(beta_init), group = beta_init)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  geom_line(size = 0.25) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  my_theme +
  # Plain strings rather than TeX(): TeX() turns the text into a plotmath
  # expression, and plotmath does not interpret "\n". It emits
  # "font metrics unknown for character 0xa" warnings instead of a line break.
  labs(x = 'Magnitude', y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_b

import numpy as np
# Tabulate the norm over a grid of correlations (theta), auxiliary
# magnitudes (beta_aux) and magnitudes (beta). Results are stored twice:
# - in the long-format dict `df` (read from R as `py$df`), one entry per
#   grid point;
# - in the dense array `norms[theta_index, beta_aux_index, beta_index]`.
alpha = 0
beta_vals = np.logspace(-2, 2, 100, base=10)
theta_vals = [1., 0.99, 0.9, 0.0]
beta_aux_vals = [0.001, 1, 1000]
df = {'norm': [], 'beta': [], 'beta_aux': [], 'theta': []}
norms = np.zeros([len(theta_vals), len(beta_aux_vals), len(beta_vals)])
for a, theta in enumerate(theta_vals):
    for b, beta_aux in enumerate(beta_aux_vals):
        # Independent of beta, so computed once per (theta, beta_aux).
        v_0 = alpha
        m_0 = np.sqrt(beta_aux)
        for c, beta in enumerate(beta_vals):
            # The roots of m^4 - m_0*theta*m^3 + v_0*beta*m - beta^2 are the
            # stationary points in m of the `norm` expression below.
            # That quartic is -beta^2 < 0 at m = 0 and has a positive
            # leading coefficient, so a positive real root exists for
            # beta > 0. If several exist (possible for alpha > 0), the last
            # one returned by np.roots is kept.
            m = None
            for root in np.roots([1, -m_0*theta, 0, v_0*beta, -beta**2]):
                if np.abs(np.imag(root)) < 1e-6 and np.real(root) > 0:
                    m = np.real(root)
            if m is None:
                raise RuntimeError(
                    f'no positive real root for theta={theta}, '
                    f'beta_aux={beta_aux}, beta={beta}')
            norm = (beta/m-v_0)**2+m**2 + m_0**2 - 2*m*m_0*theta
            # Previously `norms` was allocated but never written, so it
            # stayed all zeros.
            norms[a, b, c] = norm
            df['norm'].append(norm)
            df['beta'].append(beta)
            df['beta_aux'].append(beta_aux)
            df['theta'].append(theta)
# Pull the Python results (`py$df`, a dict of equal-length lists) into a
# tibble with double columns norm, beta, beta_aux and theta, in that order.
df_ft_relu <-
  py$df[c('norm', 'beta', 'beta_aux', 'theta')] %>%
  lapply(as.double) %>%
  as_tibble()
# Panel c: norms from `df_ft_relu` against the magnitude, one facet per
# correlation (ordered 1, 0.99, 0.9, 0 left to right) and one line per
# auxiliary magnitude, on log-log axes. The dashed line marks magnitude 1.
fig_c <-
  df_ft_relu %>%
  mutate(
    theta = factor(
      paste('Corr.:', theta),
      levels = paste('Corr.:', c(1, 0.99, 0.9, 0))
    )
  ) %>%
  ggplot(aes(beta, norm, color = factor(beta_aux), group = beta_aux)) +
  geom_vline(xintercept = 1, linetype = 'dashed', color = 'grey50') +
  geom_line(size = 0.25) +
  facet_wrap(~theta, nrow = 1) +
  scale_x_log10(breaks = c(.1, 10), labels = c('0.1', '10')) +
  scale_y_log10(
    breaks = c(1e-5, 10),
    labels = c(parse(text = '10^{-5}'), '10')
  ) +
  scale_color_manual(values = RColorBrewer::brewer.pal(6, 'Blues')[4:6]) +
  # Zoom the y axis to start at 1e-6 (coord zoom, data are not dropped).
  coord_cartesian(ylim = c(1e-6, NA)) +
  my_theme +
  # Plain strings rather than TeX(): TeX() turns the text into a plotmath
  # expression, and plotmath does not interpret "\n". It emits
  # "font metrics unknown for character 0xa" warnings instead of a line break.
  labs(x = 'Magnitude', y = 'Norm', color = 'Auxiliary\nmagnitude')
fig_c

# Combine panels a, b and c in one row (c is four times as wide) with
# automatic a/b/c tags. Panels b and c drop their colour legend, so only
# panel a's legend is drawn.
fig <- local({
  strip_legend <- function(p) p + guides(color = guide_none())
  fig_a +
    strip_legend(fig_b) +
    strip_legend(fig_c) +
    plot_layout(nrow = 1, widths = c(1, 1, 4)) +
    plot_annotation(tag_levels = 'a')
})
fig

# Write the combined figure to PDF at a 4:1 aspect ratio (`width` is in cm).
# The plot is passed explicitly: ggsave() otherwise defaults to
# last_plot(), which silently changes if another plot is built or printed
# in between.
ggsave(
  'figures/neurips-penalties.pdf',
  plot = fig,
  width = width,
  height = 0.25 * width,
  units = 'cm'
)
## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa

## Warning in grid.Call(C_textBounds, as.graphicsAnnot(x$label), x$x, x$y, : font
## metrics unknown for character 0xa
## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font metrics unknown for character 0xa

## Warning in grid.Call.graphics(C_text, as.graphicsAnnot(x$label), x$x, x$y, :
## font metrics unknown for character 0xa