# Preparation -------------------------------------------------------------

# Keep only the rows of `df` that belong to each run's best epoch.
#
# A run is one (init, coupling, it, `Nr. of Channels`) combination. Its best
# epoch is the one with the highest validation accuracy (metric 'acc', data
# 'val'); when several epochs tie, the latest of them wins. All rows of `df`
# for that run and epoch are returned, whatever their metric or data split.
# A run whose validation accuracy is entirely NA matches no epoch and drops out.
get_best_epoch <- function(df) {
  best_epochs <-
    df %>%
    filter(metric == 'acc', data == 'val') %>%
    group_by(init, coupling, it, `Nr. of Channels`) %>%
    # Overwrite every row's epoch with its run's best epoch; `which()` skips NA
    # accuracies, matching a filter on the maximum.
    mutate(epoch = as.double(
      max(epoch[which(value == max(value, na.rm = TRUE))])
    )) %>%
    ungroup()

  semi_join(df, best_epochs,
            by = c('init', 'coupling', 'it', 'Nr. of Channels', 'epoch'))
}

# CIFAR-10 performance at each run's best epoch, for every training setup.
# Columns added here: `Nr. of Channels`, folder, rec_bn (batchnorm init),
# method (kernel type) and start_at (stage at which coupling starts).
cifar_performance <- local({
  # Read the `<prefix>-16-<channels>` run folders under data/, tag each row
  # with its channel count and folder name, and keep each run's best epoch.
  # Deriving path, folder tag and channel label from one value keeps them
  # from drifting apart (they were previously typed out per run).
  load_best_runs <- function(prefix, channels) {
    runs <- lapply(channels, function(channel_label) {
      run_folder <- paste0(prefix, '-16-', channel_label)
      get_performance(paste0('data/', run_folder, '/default')) %>%
        mutate(`Nr. of Channels` = channel_label, folder = run_folder)
    })
    get_best_epoch(bind_rows(runs))
  }

  all_channels <- c('8', '16', '32', '64')

  # Uniform kernel: gamma=0.1 runs from disk plus the `performance` table
  # (defined elsewhere in this file), which is labelled gamma=1 and given
  # folder names `cifar-16-<channels>`.
  uniform_runs <-
    load_best_runs('cifar_rb', all_channels) %>%
    mutate(rec_bn = 'gamma=0.1') %>%
    bind_rows(
      performance %>%
        mutate(rec_bn = 'gamma=1',
               folder = paste0('cifar-16-', `Nr. of Channels`)) %>%
        get_best_epoch()
    ) %>%
    mutate(method = 'uniform')

  # Triangular kernel: only 16/32/64-channel runs are loaded, all gamma=1.
  triangular_runs <-
    load_best_runs('cifar_tri', c('16', '32', '64')) %>%
    mutate(rec_bn = 'gamma=1', method = 'triangular')

  # Uniform kernel with coupling starting at stage 5, both batchnorm inits.
  late_start_runs <-
    load_best_runs('cifar_startat5', all_channels) %>%
    mutate(rec_bn = 'gamma=1') %>%
    bind_rows(
      load_best_runs('cifar_startat5_rb', all_channels) %>%
        mutate(rec_bn = 'gamma=0.1')
    ) %>%
    mutate(start_at = '5', method = 'uniform')

  bind_rows(uniform_runs, triangular_runs) %>%
    mutate(start_at = '0') %>%
    bind_rows(late_start_runs) %>%
    mutate(rec_bn = factor(rec_bn, levels = c('gamma=1', 'gamma=0.1')),
           method = factor(method, levels = c('uniform', 'triangular')),
           `Nr. of Channels` = `Nr. of Channels` %>%
             factor(levels = c('8', '16', '32', '64')))
})

# Effective-parameter counts per run and epoch, summed over the rows of
# data/effective-params.csv that share (init, coupling, it, epoch, folder).
eff_params <- local({
  params_raw <- read_csv('data/effective-params.csv')
  # `key` is split on '_', ',' or '='; fields 2, 4, 6 and 8 hold init,
  # coupling, iteration and epoch (presumably name/value pairs -- confirm).
  key_fields <- str_split_fixed(params_raw$key, '[_,=]', n = 8)
  params_raw %>%
    mutate(
      init = key_fields[, 2],
      coupling = as.double(key_fields[, 4]),
      it = as.numeric(key_fields[, 6]),
      epoch = as.integer(key_fields[, 8])
    ) %>%
    select(-key, -stage) %>%
    group_by(init, coupling, it, epoch, folder) %>%
    summarise_all(sum)
})

# Keep each run's last recorded epoch of effective-parameter counts, attach
# the run's CIFAR performance, and derive the parameter summaries plotted in
# Supplementary Figures 3 and 4.
eff_params <- local({
  last_epochs <-
    eff_params %>%
    group_by(init, coupling, it, folder) %>%
    summarise(epoch = max(epoch))

  eff_params %>%
    semi_join(last_epochs,
              by = c('coupling', 'init', 'it', 'folder', 'epoch')) %>%
    ungroup() %>%
    # NOTE(review): both tables carry `epoch` and it is not a join key, so the
    # result has `epoch.x` (last eff-params epoch) and `epoch.y` (best
    # performance epoch) -- confirm this is intended.
    inner_join(cifar_performance,
               by = c('coupling', 'init', 'it', 'folder')) %>%
    mutate(
      # `Nr. of Channels` is a factor; go through its labels, not its codes.
      n_channels = as.integer(as.character(`Nr. of Channels`)),
      rec_summary = mean + dev +
        n_unshared_params(n_channels, start_at = as.integer(start_at)),
      nonrec_summary = n_params(n_channels, init, coupling,
                                start_at = as.integer(start_at))
    )
})

# Supplementary Figure 1 --------------------------------------------------

# Supplementary Figure 1: mean training / validation error per epoch for the
# 16-channel networks, one line per coupling parameter and initialization.
supp_fig_1 <-
  performance %>%
  filter(metric == 'acc', `Nr. of Channels` == '16') %>%
  mutate(
    error_rate = 1 - value,
    data = if_else(data == 'train', 'Training data', 'Validation data'),
    `Coupling parameter` = paste0(coupling,
                                  if_else(init == 'nr', ' (non-rec.)', ' (rec.)'))
  ) %>%
  group_by(`Coupling parameter`, epoch, data) %>%
  summarise(error_rate = mean(error_rate)) %>%
  ggplot(aes(epoch, error_rate,
             color = `Coupling parameter`,
             fill = `Coupling parameter`)) +
  geom_line() +
  facet_grid(~data) +
  scale_color_viridis_d(option = 'cividis') +
  scale_fill_viridis_d(option = 'cividis') +
  theme_minimal() +
  scale_y_continuous(labels = scales::percent,
                     name = 'Error rate on CIFAR-10') +
  xlab('Epoch') +
  labs(color = 'Coupling parameter (Initialization)')

# Save explicitly: `plot + ggsave(...)` errors on ggplot2 >= 3.3.4, where
# ggsave() returns the file path, which cannot be added to a plot.
ggsave('figures/supp-fig-1.pdf', plot = supp_fig_1, width = 8, height = 3)


# Supplementary Figure 2 --------------------------------------------------

# Supplementary Figure 2a: error rate vs. coupling parameter for the uniform
# kernel with coupling from the start, comparing batchnorm initializations.
# Faint points are single runs; large points and lines are means over runs.
supp_fig_2_a <-
  cifar_performance %>%
  filter(method == 'uniform', start_at == '0') %>%
  filter(metric == 'acc') %>%
  mutate(error_rate = 1 - value,
         coupling = paste0(coupling, '\n(',
                           if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
         data = if_else(data == 'train', 'Training data', 'Validation data')) %>%
  ggplot(aes(coupling, error_rate, shape = rec_bn, linetype = rec_bn,
             color = `Nr. of Channels`)) +
  geom_point(alpha = 0.5) +
  # `fun` replaces `fun.y`, deprecated since ggplot2 3.3.0.
  stat_summary(geom = 'point', size = 3, fun = mean) +
  stat_summary(mapping = aes(group = paste(`Nr. of Channels`, rec_bn)),
               geom = 'line', fun = mean) +
  facet_wrap(~data) +
  scale_y_continuous(labels = scales::percent) +
  scale_color_scico_d(palette = 'berlin') +
  labs(x = 'Coupling parameter (Initialization)',
       y = 'Error rate on CIFAR-10',
       color = '# channels',
       shape = 'Batchnorm initialization',
       linetype = 'Batchnorm initialization') +
  theme_minimal()

# Supplementary Figure 2b: error rate vs. coupling parameter, uniform vs.
# triangular kernel (coupling from the start, 16-64 channels).
supp_fig_2_b <-
  cifar_performance %>%
  filter(`Nr. of Channels` != '8', start_at == '0',
         # Triangular-kernel runs exist only with gamma=1. Restrict the
         # uniform runs to the same batchnorm init. Otherwise their means
         # also average in the gamma=0.1 runs and confound the comparison.
         rec_bn == 'gamma=1') %>%
  filter(metric == 'acc') %>%
  mutate(error_rate = 1 - value,
         coupling = paste0(coupling, '\n(',
                           if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
         data = if_else(data == 'train', 'Training data', 'Validation data')) %>%
  ggplot(aes(coupling, error_rate, shape = method, linetype = method,
             color = `Nr. of Channels`)) +
  geom_point(alpha = 0.5) +
  # `fun` replaces `fun.y`, deprecated since ggplot2 3.3.0.
  stat_summary(geom = 'point', size = 3, fun = mean) +
  stat_summary(mapping = aes(group = paste(`Nr. of Channels`, method)),
               geom = 'line', fun = mean) +
  facet_wrap(~data) +
  scale_y_continuous(labels = scales::percent) +
  scale_color_scico_d(palette = 'berlin') +
  labs(x = 'Coupling parameter (Initialization)',
       y = 'Error rate on CIFAR-10',
       color = '# channels',
       shape = 'Kernel type',
       linetype = 'Kernel type') +
  theme_minimal()

# Supplementary Figure 2c: error rate vs. coupling parameter when coupling
# starts at stage 5, comparing batchnorm initializations.
supp_fig_2_c <-
  cifar_performance %>%
  filter(start_at == '5') %>%
  filter(metric == 'acc') %>%
  mutate(error_rate = 1 - value,
         coupling = paste0(coupling, '\n(',
                           if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
         data = if_else(data == 'train', 'Training data', 'Validation data')) %>%
  ggplot(aes(coupling, error_rate, shape = rec_bn, linetype = rec_bn,
             color = `Nr. of Channels`)) +
  geom_point(alpha = 0.5) +
  # `fun` replaces `fun.y`, deprecated since ggplot2 3.3.0.
  stat_summary(geom = 'point', size = 3, fun = mean) +
  stat_summary(mapping = aes(group = paste(`Nr. of Channels`, rec_bn)),
               geom = 'line', fun = mean) +
  facet_wrap(~data) +
  scale_y_continuous(labels = scales::percent) +
  scale_color_scico_d(palette = 'berlin') +
  labs(x = 'Coupling parameter (Initialization)',
       y = 'Error rate on CIFAR-10',
       color = '# channels',
       shape = 'Batchnorm initialization',
       linetype = 'Batchnorm initialization') +
  theme_minimal()

# Stack panels a-c vertically with patchwork and tag them 'a', 'b', 'c'.
supp_fig_2 <-
  (supp_fig_2_a / supp_fig_2_b / supp_fig_2_c) +
  plot_annotation(tag_levels = 'a')

# Save explicitly: `plot + ggsave(...)` errors on ggplot2 >= 3.3.4, where
# ggsave() returns the file path, which cannot be added to a plot.
ggsave('figures/supp-fig-2.pdf', plot = supp_fig_2, width = 8, height = 8)


# Supplementary Figure 3 --------------------------------------------------

# Supplementary Figure 3: validation error vs. raw and effective parameter
# counts (log-log), one facet per parameter count.
supp_fig_3 <-
  eff_params %>%
  filter(data == 'val', metric == 'acc') %>%
  # Use a fresh key name rather than `metric`. Reusing `metric` relies on
  # gather() silently dropping the existing `metric` column.
  gather('param_kind', 'parameters', rec_summary, nonrec_summary) %>%
  mutate(coupling = paste0(coupling, ' (',
                           if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
         error_rate = 1 - value,
         param_kind = if_else(param_kind == 'nonrec_summary',
                              '# raw parameters',
                              '# effective parameters') %>%
           factor(levels = c('# raw parameters', '# effective parameters'))) %>%
  ggplot(aes(parameters, error_rate, color = coupling)) +
  geom_point() +
  scale_x_log10(breaks = c(1e5, 1e7), labels = TeX(c('$10^5$', '$10^7$')),
                name = NULL) +
  scale_y_log10(labels = scales::percent,
                name = 'Error rate on CIFAR-10') +
  scale_color_viridis_d(option = 'cividis',
                        name = 'Coupling parameter (Initialization)') +
  # Strips sit below the panels and act as the x-axis titles.
  facet_wrap(~param_kind, strip.position = 'bottom') +
  theme_minimal() +
  theme(strip.placement = 'outside')

# Save explicitly: `plot + ggsave(...)` errors on ggplot2 >= 3.3.4, where
# ggsave() returns the file path, which cannot be added to a plot.
ggsave('figures/supp-fig-3.pdf', plot = supp_fig_3, width = 8, height = 4)


# Supplementary Figure 4 --------------------------------------------------

# Supplementary Figure 4: effective vs. raw parameter counts per run (log-log)
# with the identity line dashed for reference.
supp_fig_4 <-
  eff_params %>%
  ggplot(aes(nonrec_summary, rec_summary,
             color = paste0(coupling, ' (',
                            if_else(init == 'nr', 'non-rec.', 'rec.'), ')'))) +
  geom_abline(slope = 1, color = 'grey', linetype = 'dashed') +
  geom_point() +
  scale_x_log10(name = '# raw parameters',
                breaks = c(1e5, 1e7), labels = TeX(c('$10^5$', '$10^7$'))) +
  scale_y_log10(name = '# effective parameters',
                breaks = c(1e5, 1e7), labels = TeX(c('$10^5$', '$10^7$'))) +
  scale_color_viridis_d(option = 'cividis',
                        name = 'Coupling parameter (Initialization)') +
  theme_minimal()

# Save explicitly: `plot + ggsave(...)` errors on ggplot2 >= 3.3.4, where
# ggsave() returns the file path, which cannot be added to a plot.
ggsave('figures/supp-fig-4.pdf', plot = supp_fig_4, width = 8, height = 4)

# Supplementary Tables 1 and 2 ---------------------------------------------------

# MNIST performance at each run's best epoch, for 16, 32 and 64 channels.
mnist_performance <-
  c('16', '32', '64') %>%
  lapply(function(channel_label) {
    get_performance(paste0('data/mnist-16-', channel_label, '/default')) %>%
      mutate(`Nr. of Channels` = channel_label)
  }) %>%
  bind_rows() %>%
  get_best_epoch()

# MNIST error table: mean validation error (in %, 3 decimals) with one row per
# channel count and one column per coupling parameter / initialization,
# printed as LaTeX.
mnist_performance %>%
  filter(data == 'val', metric == 'acc') %>%
  group_by(`Nr. of Channels`, init, coupling) %>%
  summarise(error_rate = paste(round(100 * mean(1 - value, na.rm = TRUE), 3), '%')) %>%
  ungroup() %>%
  transmute(
    `# channels` = `Nr. of Channels`,
    coupling = paste0(coupling, ' (',
                      if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
    error_rate
  ) %>%
  spread('coupling', 'error_rate') %>%
  xtable() %>%
  print(include.rownames = FALSE)

# Digit-clutter (3_1) performance at each run's best epoch, for 16, 32 and 64
# channels.
digclut_performance <-
  c('16', '32', '64') %>%
  lapply(function(channel_label) {
    get_performance(paste0('data/digitclutter_3_1-16-', channel_label,
                           '/default')) %>%
      mutate(`Nr. of Channels` = channel_label)
  }) %>%
  bind_rows() %>%
  get_best_epoch()

# Digit-clutter error table: mean validation error (in %, 3 decimals) with one
# row per channel count and one column per coupling parameter /
# initialization, printed as LaTeX.
digclut_performance %>%
  filter(data == 'val', metric == 'acc') %>%
  group_by(`Nr. of Channels`, init, coupling) %>%
  summarise(error_rate = paste(round(100 * mean(1 - value, na.rm = TRUE), 3), '%')) %>%
  ungroup() %>%
  transmute(
    `# channels` = `Nr. of Channels`,
    coupling = paste0(coupling, ' (',
                      if_else(init == 'nr', 'non-rec.', 'rec.'), ')'),
    error_rate
  ) %>%
  spread('coupling', 'error_rate') %>%
  xtable() %>%
  print(include.rownames = FALSE)


# Supplementary Figure 5 --------------------------------------------------

# Load the per-run SIR metric CSVs (one per training folder) and tag each row
# with its folder. Parse the run identity out of `model_folder`, then decode
# the flat `layer` index into a (stage, block) pair for each metric type.
# Folder names are listed once and reused for both the CSV path and the
# `folder` tag (previously each of the 15 runs spelled them out twice).
raw_sir <-
  c(paste0('cifar_rb-16-', c('8', '16', '32', '64')),
    paste0('cifar_startat5-16-', c('8', '16', '32', '64')),
    paste0('cifar_startat5_rb-16-', c('8', '16', '32', '64')),
    paste0('cifar_tri-16-', c('16', '32', '64'))) %>%
  lapply(function(run_folder) {
    read_csv(paste0('data/sir_metrics/', run_folder, '.csv')) %>%
      mutate(folder = run_folder)
  }) %>%
  bind_rows() %>%
  # `model_folder` split on '_', ',' or '=': fields 2, 4, 6 are init,
  # coupling and iteration.
  mutate(init = str_split_fixed(model_folder, '[_,=]', 6)[, 2],
         coupling = str_split_fixed(model_folder, '[_,=]', 6)[, 4] %>% as.numeric(),
         it = str_split_fixed(model_folder, '[_,=]', 6)[, 6] %>% as.numeric()) %>%
  mutate(
    # layer_dropout indexes 16 blocks per stage (numbered 1-16);
    # early_late_readout indexes 33 per stage (numbered 0-32).
    block = case_when(
      metric == 'layer_dropout' ~ (layer %% 16) + 1,
      metric == 'early_late_readout' ~ layer %% 33
    ),
    stage = if_else(metric == 'layer_dropout',
                    floor(layer / 16), floor(layer / 33)) + 1,
    # Readout blocks 0-16 are early readouts, 17-32 late readouts.
    metric = case_when(
      metric == 'layer_dropout' ~ 'layer_dropout',
      block > 16 ~ 'late_readout',
      block <= 16 ~ 'early_readout'
    ),
    # Renumber late readouts to 1-16 (uses the updated `metric`).
    block = if_else(metric == 'late_readout', block - 16, block)
  )

# Finish the SIR curves: give the late-readout curve a block-0 starting point,
# keep only accuracy-based rows, and attach error rates and display names.
raw_sir <- local({
  # Late readouts are numbered 1-16, so reuse the last early-readout point
  # (block 16) as the late-readout curve's block-0 value.
  late_readout_start <-
    raw_sir %>%
    filter(metric == 'early_readout', block == 16) %>%
    mutate(metric = 'late_readout', block = 0)

  # Display name for each internal metric; anything else maps to NA.
  index_labels <- c(early_readout = 'Convergence Index',
                    late_readout = 'Divergence Index',
                    layer_dropout = 'Recurrence Index')

  bind_rows(raw_sir, late_readout_start) %>%
    filter(criterion == 'acc') %>%
    mutate(
      error_rate = 1 - value,
      metric = factor(unname(index_labels[metric]),
                      levels = c('Convergence Index', 'Recurrence Index',
                                 'Divergence Index'))
    )
})

# Area-under-curve score per run, index and stage.
# Each curve is ordered by block, then rescaled so its reference point maps
# to 0 and an error rate of 0.9 maps to 1. The 0.9 is presumably chance-level
# error for 10 classes -- confirm. The reference point is the first block for
# the Divergence Index and the last block otherwise. The AUC is taken over an
# evenly spaced [0, 1] grid; for every index except Divergence the score is
# 1 - AUC.
stagewise_sir <-
  raw_sir %>%
  group_by(init, coupling, it, folder, metric, stage) %>%
  nest() %>%
  mutate(auc = map2_dbl(metric, data, function(index_name, curve) {
    is_divergence <- index_name == 'Divergence Index'
    error_curve <- curve$error_rate[order(curve$block)]
    baseline <- if (is_divergence) {
      error_curve[1]
    } else {
      error_curve[length(error_curve)]
    }
    scaled <- (error_curve - baseline) / (0.9 - baseline)
    area <- AUC(seq(0, 1, length.out = length(scaled)), scaled)
    if (is_divergence) area else 1 - area
  })) %>%
  select(-data) %>%
  ungroup()

# Per-run SIR scores (AUC averaged over stages) joined with the run's
# validation accuracy row from cifar_performance. The performance table's own
# `metric`/`data` columns are dropped, so `metric` keeps naming the SIR index.
full_sir <- inner_join(
  stagewise_sir %>%
    group_by(init, coupling, it, folder, metric) %>%
    summarise(auc = mean(auc)),
  cifar_performance %>%
    filter(metric == 'acc', data == 'val') %>%
    select(-metric, -data),
  by = c('init', 'coupling', 'it', 'folder')
)

# Supplementary Figure 5: mean SIR AUC vs. coupling parameter. Facet rows are
# training setups (coupling start stage x kernel type); facet columns are
# indices. Lines are split by batchnorm init and channel count.
supp_fig_5 <-
  full_sir %>%
  ungroup() %>%
  mutate(setup = paste0('Start coupling at: ', start_at, '\n',
                        'Kernel type: ', method) %>%
           factor(levels =
                    c('Start coupling at: 0\nKernel type: uniform',
                      'Start coupling at: 0\nKernel type: triangular',
                      'Start coupling at: 5\nKernel type: uniform')),
         coupling = as.character(coupling)) %>%
  ggplot(aes(coupling, auc, shape = rec_bn, linetype = rec_bn,
             color = `Nr. of Channels`,
             group = paste(rec_bn, `Nr. of Channels`))) +
  # `fun` replaces `fun.y`, deprecated since ggplot2 3.3.0.
  stat_summary(geom = 'line', fun = mean) +
  stat_summary(geom = 'point', fun = mean) +
  facet_grid(setup ~ metric, scales = 'free') +
  scale_color_scico_d(palette = 'berlin') +
  theme_minimal() +
  labs(y = 'AUC', x = 'Coupling parameter',
       color = '# channels', shape = 'Batchnorm initialization',
       linetype = 'Batchnorm initialization') +
  theme(legend.position = 'top')

# Save explicitly: `plot + ggsave(...)` errors on ggplot2 >= 3.3.4, where
# ggsave() returns the file path, which cannot be added to a plot.
ggsave('figures/supp-fig-5.pdf', plot = supp_fig_5, width = 8, height = 6)

