library(tidyverse)
library(patchwork)
library(scico)
library(DescTools)
library(xtable)
library(magrittr)
library(latex2exp)


# Preparation -------------------------------------------------------------

# Performance =============================================================

# Read every `<path>/<run folder>/metrics.csv` and return the per-epoch
# metrics in long format, averaged over duplicate rows.
#
# Run folder names are split on '=', ',' and '_'; tokens 2, 4 and 6 are used
# as the initialization, coupling parameter and iteration (NOTE(review):
# presumably names like `init=nr_coupling=0.5_it=3` -- confirm against data).
# Entries of `path` without a metrics.csv are skipped.
#
# Args:
#   path: directory containing one sub-folder per training run.
# Returns:
#   An ungrouped tibble with columns init, epoch, coupling, it, data
#   ('train'/'val'), metric ('loss'/'acc') and value.
get_performance <- function(path) {
  run_folders <- list.files(path)
  has_metrics <- file.exists(file.path(path, run_folders, 'metrics.csv'))

  runs <- map(run_folders[has_metrics], function(folder) {
    info <- str_split(folder, '[=,_]')[[1]]
    read_csv(file.path(path, folder, 'metrics.csv')) %>%
      mutate(
        init = info[2],
        coupling = as.double(info[4]),
        it = as.integer(info[6])
      )
  })

  if (length(runs) == 0) {
    stop('No metrics.csv found under ', path, call. = FALSE)
  }

  bind_rows(runs) %>%
    pivot_longer(
      c(train_loss, train_acc, val_loss, val_acc),
      names_to = 'metric', values_to = 'value'
    ) %>%
    group_by(init, epoch, coupling, it, metric) %>%
    summarise(value = mean(value, na.rm = TRUE)) %>%
    # e.g. 'train_acc' -> data = 'train', metric = 'acc'
    separate(metric, c('data', 'metric')) %>%
    ungroup()
}

# Training/validation metrics for every channel width, each tagged with its
# width as a character label in `Nr. of Channels`.
performance <-
  c('8', '16', '32', '64') %>%
  map(function(n_channels) {
    paste0('data/cifar-16-', n_channels, '/default') %>%
      get_performance() %>%
      mutate(`Nr. of Channels` = n_channels)
  }) %>%
  bind_rows()

# For each run (init, coupling, iteration, channel width), find the epoch with
# the highest validation accuracy; ties resolve to the latest such epoch.
best_epoch <-
  performance %>%
  filter(metric == 'acc', data == 'val') %>%
  group_by(init, coupling, it, `Nr. of Channels`) %>%
  nest() %>%
  mutate(epoch = map_dbl(data, function(run) {
    top_acc <- max(run$value, na.rm = TRUE)
    # which() skips NA comparisons
    max(run$epoch[which(run$value == top_acc)])
  })) %>%
  select(-data)

# Rows of `performance` at each run's best validation-accuracy epoch.
# `epoch` must be part of the join key: joining on the run keys alone would
# keep every epoch of every run.
best_performance <-
  performance %>%
  semi_join(best_epoch,
            by = c('init', 'coupling', 'it', 'Nr. of Channels', 'epoch'))

# Indices of iterative convergence ========================================

# Load the per-layer convergence metrics of every channel width, decode the
# run parameters from `model_folder`, and map each row's `layer` to a
# (stage, block) position.
raw_sir <-
  bind_rows(
    read_csv('data/sir_metrics/cifar-16-8.csv') %>%
      mutate(`Nr. of Channels` = '8'),
    read_csv('data/sir_metrics/cifar-16-16.csv') %>%
      mutate(`Nr. of Channels` = '16'),
    read_csv('data/sir_metrics/cifar-16-32.csv') %>%
      mutate(`Nr. of Channels` = '32'),
    read_csv('data/sir_metrics/cifar-16-64.csv') %>%
      mutate(`Nr. of Channels` = '64')
  ) %>%
  # model_folder is split on '_', ',' and '='; tokens 2, 4 and 6 are taken as
  # init, coupling and iteration.
  mutate(init = str_split_fixed(model_folder, '[_,=]', 6)[,2],
         coupling = str_split_fixed(model_folder, '[_,=]', 6)[,4] %>% as.numeric(),
         it = str_split_fixed(model_folder, '[_,=]', 6)[,6] %>% as.numeric()) %>%
  # layer_dropout rows: 16 layers per stage -> block 1..16.
  # early_late_readout rows: 33 positions per stage -> block 0..32, where
  # 0..16 is the early readout and 17..32 the late readout.
  mutate(block = case_when(
    metric == 'layer_dropout' ~ (layer %% 16)+1,
    metric == 'early_late_readout' ~ layer %% 33
  ),
  # `stage` still sees the incoming metric names; `metric` is only
  # overwritten by the next argument of this mutate().
  stage = if_else(metric == 'layer_dropout', floor(layer/16), floor(layer/33))+1,
  metric = case_when(
    metric == 'layer_dropout' ~ 'layer_dropout',
    block > 16 ~ 'late_readout',
    block <= 16 ~ 'early_readout'
  ),
  # Shift late-readout blocks 17..32 down to 1..16.
  block = if_else(metric == 'late_readout', block - 16, block))

# Copy block 16 of each early readout as block 0 of the late readout, keep
# only the accuracy criterion, add the error rate, and relabel the metrics as
# the three named indices (a factor with a fixed level order).
raw_sir <-
  raw_sir %>%
  bind_rows(
    raw_sir %>%
      filter(metric == 'early_readout', block == 16) %>%
      mutate(metric = 'late_readout', block = 0)
  ) %>%
  filter(criterion == 'acc') %>%
  mutate(
    error_rate = 1-value,
    metric =
      c(early_readout = 'Convergence Index',
        late_readout = 'Divergence Index',
        layer_dropout = 'Recurrence Index')[metric] %>%
      unname() %>%
      factor(levels = c('Convergence Index', 'Recurrence Index', 'Divergence Index'))
  )

# Area under the normalised error curve for each run, index and stage.
#
# Error rates (ordered by block) are rescaled so that the reference point maps
# to 0 and an error of 0.9 maps to 1 (NOTE(review): 0.9 is presumably the
# chance-level error on 10 classes). The reference is the first point for the
# Divergence Index and the last point otherwise. The area over an evenly
# spaced [0, 1] grid comes from DescTools::AUC; for the Convergence and
# Recurrence indices it is reported as 1 - area.
stagewise_sir <-
  raw_sir %>%
  group_by(init, coupling, it, `Nr. of Channels`, metric, stage) %>%
  nest() %>%
  mutate(auc = map2_dbl(metric, data, function(index_name, curve) {
    err <- curve$error_rate[order(curve$block)]
    is_divergence <- index_name == 'Divergence Index'
    reference <- if (is_divergence) err[1] else err[length(err)]
    scaled <- (err - reference) / (0.9 - reference)
    area <- AUC(seq(0, 1, length.out = length(scaled)), scaled)
    if (is_divergence) area else 1 - area
  })) %>%
  select(-data) %>%
  ungroup()

# Stage-averaged AUC per run and index. The result is ungrouped so later
# joins and mutates do not silently operate per run.
full_sir <-
  stagewise_sir %>%
  group_by(init, coupling, it, `Nr. of Channels`, metric) %>%
  summarise(auc = mean(auc)) %>%
  ungroup()

# Effective number of parameters ==========================================

# Raw parameter count of a network with `n_channels` base channels.
#
# Runs with init 'r' and coupling 1 are counted with a single set of per-step
# weights; all other runs are counted with 16 sets. The never-shared
# remainder comes from n_unshared_params().
#
# Args (vectorised):
#   n_channels: number of base channels.
#   init, coupling: run initialization label and coupling parameter.
#   start_at: forwarded to n_unshared_params().
# Returns:
#   Numeric vector of parameter counts.
n_params <- function(n_channels, init, coupling, start_at=0) {
  per_step <- (1+4+16)*n_channels^2*9 + (1+2+4)*n_channels
  n_weight_sets <- if_else(init == 'r' & coupling == 1, 1, 16)
  2*per_step*n_weight_sets + n_unshared_params(n_channels, start_at=start_at)
}

# Number of parameters that are never shared across recurrent steps, for a
# network with `n_channels` base channels.
#
# The first four terms depend only on the channel width (NOTE(review): the
# individual constants presumably correspond to stem/transition/readout
# layers -- confirm against the model definition). The last term adds
# `start_at` copies of the same per-step block that n_params() multiplies by
# its number of weight sets.
#
# Args (vectorised):
#   n_channels: number of base channels.
#   start_at: number of per-step blocks counted as unshared; default 0.
# Returns:
#   Numeric vector of parameter counts.
n_unshared_params <- function(n_channels, start_at=0) {
  params <- 3*n_channels*9+2*n_channels
  params <- params + 2*n_channels*2*(n_channels*9+3/2)
  params <- params + 2*2*n_channels*(4*n_channels+3/2)*9
  params <- params + 2*4*n_channels + 4*n_channels*10 + 16*n_channels*9
  # Accumulate rather than overwrite: assigning here would discard every term
  # above and return 0 for start_at = 0.
  params <- params + 2*((1+4+16)*n_channels^2*9+(1+2+4)*n_channels)*start_at
  params
}

# Load the effective-parameter estimates and summarise them per run and epoch.
# The remaining numeric columns (including `mean` and `dev`) are summed over
# the rows that differ only in the dropped `stage`/`key`/`folder` columns.
# NOTE(review): `mean`/`dev` are presumably an effective-parameter estimate and
# a spread term -- confirm against the producer of effective-params.csv.
eff_params <-
  read_csv('data/effective-params.csv') %>%
  mutate(
    # `key` is split on '_', ',' and '='; tokens 2/4/6/8 are taken as
    # init, coupling, iteration and epoch.
    init = str_split_fixed(key, '[_,=]', n=8)[,2],
    coupling = str_split_fixed(key, '[_,=]', n=8)[,4] %>% as.double(),
    it = str_split_fixed(key, '[_,=]', n=8)[,6] %>% as.numeric(),
    epoch =  str_split_fixed(key, '[_,=]', n=8)[,8] %>% as.integer(),
    # Channel width: third '-'-separated token of `folder`, up to the first '_'.
    `Nr. of Channels` =
      folder %>%
      str_split_fixed('-', n = 3) %>%
      extract(,3) %>%
      str_split_fixed('_', n = 2) %>%
      extract(,1)
  ) %>%
  select(-key, -folder, -stage) %>%
  group_by(init, coupling, it, epoch, `Nr. of Channels`) %>%
  summarise_all(sum) %>%
  # summarise() drops only the last grouping variable, so this mutate() still
  # runs grouped; both expressions are vectorised, so grouping does not change
  # the values.
  mutate(
    # effective count: summed estimates plus the never-shared parameters
    rec_summary = mean + dev + n_unshared_params(as.integer(`Nr. of Channels`)),
    # raw count for the run's init/coupling
    nonrec_summary = n_params(as.integer(`Nr. of Channels`), init, coupling)
  )

# Keep only the rows from the last recorded epoch of every run.
eff_params <-
  eff_params %>%
  group_by(init, coupling, it, `Nr. of Channels`) %>%
  # %in% (not ==) also keeps NA epochs when a run's maximum epoch is NA
  filter(epoch %in% max(epoch)) %>%
  ungroup()


# Figure 2 ----------------------------------------------------------------

# Figure 2a: error rate of the 16-channel 'nr' (non-recurrent) runs as more
# blocks are evaluated, coloured by stage and faceted by index. All
# iterations are drawn faintly; iteration 0 is overlaid at full opacity.
df_fig_2_a <-
  raw_sir %>%
  filter(init == 'nr', criterion == 'acc', `Nr. of Channels` == '16')

fig_2_a <-
  ggplot(df_fig_2_a, aes(x = block, y = error_rate, colour = factor(stage))) +
  geom_point(alpha = 0.2, show.legend = FALSE) +
  geom_line(aes(group = paste(stage, it)),
            linetype = 'dashed', alpha = 0.2, show.legend = FALSE) +
  geom_point(data = filter(df_fig_2_a, it == 0), show.legend = FALSE) +
  geom_line(data = filter(df_fig_2_a, it == 0), show.legend = FALSE) +
  facet_wrap(~metric, nrow = 3, scales = 'free_x') +
  scale_color_brewer(palette = 'Dark2') +
  scale_y_continuous(labels = scales::percent, breaks = c(0, 0.5, 0.9),
                     limits = c(0, 1)) +
  labs(x = '# additional evaluations', y = 'Error rate on CIFAR-10') +
  theme_minimal()

# Figure 2b data: per-stage AUCs plus their stage average for the 16-channel
# 'nr' runs. `stage` becomes a factor whose levels put 'Average' last.
df_fig_2_b <-
  stagewise_sir %>%
  mutate(stage = paste('Stage:', as.character(stage))) %>%
  bind_rows(mutate(full_sir, stage = 'Average')) %>%
  filter(init == 'nr', `Nr. of Channels` == '16') %>%
  mutate(stage = factor(stage, levels = c('Stage: 1', 'Stage: 2', 'Stage: 3', 'Average')))

# Figure 2b: AUC per stage (and stage average) for each index, with axes
# flipped. All iterations are drawn semi-transparent; iteration 0 is drawn
# larger on top.
fig_2_b <-
  ggplot(df_fig_2_b, aes(x = stage, y = auc, colour = stage)) +
  # Points are dodged, not jittered: an explicit position_dodge() would
  # override geom_jitter()'s jitter anyway.
  geom_point(alpha = 0.5, position = position_dodge(width = 1)) +
  geom_point(data = filter(df_fig_2_b, it == 0),
             position = position_dodge(width = 1), size = 2) +
  facet_wrap(~metric, ncol = 1, scales = 'free') +
  coord_flip() +
  scale_x_discrete(breaks = NULL) +
  scale_y_continuous(limits = c(0, 1), breaks = c(0, 0.5, 1.)) +
  scale_color_brewer(palette = 'Dark2') +
  labs(x = NULL, y = NULL, color = NULL) +
  theme_minimal()

# Figure 2c: stage-averaged convergence/divergence AUC of each 16-channel 'nr'
# run against its validation error (1 - accuracy) at the best epoch.
fig_2_c <-
  full_sir %>%
  inner_join(
    performance %>%
      filter(metric == 'acc', data == 'val') %>%
      select(-metric) %>%
      semi_join(best_epoch,
                by = c('init', 'coupling', 'it', 'Nr. of Channels', 'epoch')) %>%
      select(-epoch),
    by = c('coupling', 'init', 'it', 'Nr. of Channels')
  ) %>%
  ungroup() %>%
  # `Nr. of Channels` is a character label throughout this script; compare
  # against '16' instead of relying on number-to-string coercion.
  filter(init == 'nr', `Nr. of Channels` == '16',
         metric %in% c('Convergence Index', 'Divergence Index')) %>%
  ggplot(aes(auc, 1-value)) +
  geom_point() +
  theme_minimal() +
  labs(y = 'Error rate on CIFAR-10', x = NULL) +
  # NOTE: runs with an error outside [4.75%, 5.5%] are dropped by these limits.
  scale_y_log10(labels = scales::percent, limits = c(0.0475, 0.055)) +
  facet_wrap(~metric, ncol = 1, scales = 'free_x') +
  scale_x_continuous(limits = c(0, 1), breaks = c(0, 0.5, 1))

# Compose Figure 2 (panels a-c side by side, shared strip styling, collected
# legends) and save it. The composite is passed to ggsave() explicitly:
# `<plot> + ggsave()` stopped working in ggplot2 3.3.4, when ggsave() began
# returning the file name instead of NULL.
fig_2 <-
  (fig_2_a + fig_2_b + fig_2_c &
     theme(strip.text = element_text(size = 10, color = 'black'))) +
  plot_layout(guides = 'collect')

ggsave('figures/fig-2-raw.pdf', plot = fig_2, width = 8, height = 4)


# Figure 3 ----------------------------------------------------------------

# Figure 3a: stage-averaged AUC of each index per coupling parameter and
# initialization, coloured by channel width. Large points and lines show the
# mean over iterations.
fig_3_a <-
  full_sir %>%
  ungroup() %>%
  mutate(coupling = paste0(coupling, '\n', if_else(init=='nr', '(non-rec.)', '(rec.)')) %>%
           factor(
             levels = c('0\n(non-rec.)', '0\n(rec.)', '0.25\n(rec.)',
                        '0.5\n(rec.)', '0.9\n(rec.)', '1\n(rec.)')),
         `Nr. of Channels` = factor(`Nr. of Channels`,
                                    levels = c('8', '16', '32', '64'))) %>%
  ggplot(aes(coupling, auc, color = `Nr. of Channels`)) +
  geom_jitter(width=0.2, alpha=0.5) +
  # `fun` replaces `fun.y`, which has been deprecated since ggplot2 3.3.0.
  stat_summary(geom = 'point', size = 3, fun = 'mean') +
  stat_summary(mapping = aes(group = paste(`Nr. of Channels`, metric)),
               geom = 'line', fun = 'mean') +
  facet_wrap(~metric, scales='free') +
  scale_color_scico_d(palette='berlin', drop = TRUE) +
  labs(x = 'Coupling parameter (Initialization)', y = 'AUC',
       shape = NULL, linetype = NULL, color = '# channels') +
  theme_minimal()

# Figure 3b data: convergence and divergence AUCs of every run joined with the
# run's best-epoch validation accuracy (`value`) and its parameter counts.
df_fig_3_b <-
  full_sir %>%
  filter(metric != 'Recurrence Index') %>%
  inner_join(
    performance %>%
      filter(metric == 'acc', data == 'val') %>%
      semi_join(best_epoch,
                by = c('init', 'coupling', 'it', 'Nr. of Channels', 'epoch')) %>%
      select(-metric, -epoch),
    by = c('coupling', 'init', 'it', 'Nr. of Channels')
  ) %>%
  inner_join(eff_params, by = c('coupling', 'init', 'it', 'Nr. of Channels'))

# Trend lines for Figure 3b. For each raw parameter count and index, fit
# log(error) linearly in AUC across runs, then keep the fitted error at the
# smallest and largest AUC (the two endpoints of each line).
df_fig_3_b_2 <-
  df_fig_3_b %>%
  group_by(nonrec_summary, metric) %>%
  nest() %>%
  mutate(
    predictions = map(data, function(runs) {
      fit <- lm(log(1-value) ~ auc, runs)
      lo <- min(runs$auc)
      hi <- max(runs$auc)
      runs %>%
        mutate(prediction = exp(predict(fit))) %>%
        filter(auc == lo | auc == hi) %>%
        select(auc, prediction)
    })
  ) %>%
  select(-data) %>%
  unnest(cols = c(predictions))

# Figure 3b: best-epoch validation error against AUC, coloured by raw
# parameter count (log colour scale), with the per-count log-linear fits.
fig_3_b <-
  ggplot(df_fig_3_b, aes(x = auc, y = 1-value, colour = nonrec_summary)) +
  geom_point() +
  geom_line(data = df_fig_3_b_2,
            mapping = aes(x = auc, y = prediction, group = factor(nonrec_summary))) +
  facet_wrap(~metric, nrow = 1, scales = 'free_x') +
  scale_y_log10(labels = scales::percent) +
  scale_color_viridis_c(
    option = 'magma',
    trans = scales::log10_trans(),
    breaks = c(1e5, 1e7),
    labels = TeX(c('$10^5$', '$10^7$'))
  ) +
  labs(x = NULL, y = 'Error rate', color = '# parameters') +
  theme_minimal() +
  theme(strip.text = element_text(size = 10, color = 'black'),
        axis.title = element_text(size = 10))

# Figure 3c data: each run's best-epoch validation accuracy (`value`) paired
# with both its raw and effective parameter counts, one row per count type.
df_fig_3_c <-
  eff_params %>%
  select(-epoch) %>%
  inner_join(
    performance %>%
      filter(metric == 'acc', data == 'val') %>%
      semi_join(best_epoch,
                by = c('coupling', 'init', 'it', 'Nr. of Channels', 'epoch')) %>%
      # Drop the constant 'acc' label so it cannot collide with the `metric`
      # column created by the reshape below.
      select(-metric, -epoch),
    by = c('coupling', 'init', 'it', 'Nr. of Channels')
  ) %>%
  pivot_longer(c(nonrec_summary, rec_summary),
               names_to = 'metric', values_to = 'parameters') %>%
  mutate(coupling = 
           paste0(coupling, if_else(init == 'nr', '\n(non-rec.)', '\n(rec.)')) %>%
           factor(levels = c('0\n(non-rec.)', '0\n(rec.)', '0.25\n(rec.)',
                             '0.5\n(rec.)', '0.9\n(rec.)', '1\n(rec.)')),
         metric = case_when(
           metric == 'nonrec_summary' ~ '# raw parameters',
           metric == 'rec_summary' ~ '# effective parameters'
         ) %>%
           factor(levels = c('# raw parameters', '# effective parameters')))

# Figure 3c: validation error against raw and effective parameter counts, one
# colour per coupling/initialization. Lines join the per-channel-width means.
fig_3_c <-
  df_fig_3_c %>%
  ggplot(aes(parameters, 1-value, color = coupling)) +
  geom_line(
    data =
      df_fig_3_c %>%
      group_by(`Nr. of Channels`, coupling, metric) %>%
      summarise(parameters = mean(parameters), value = mean(value))
  ) +
  geom_point(alpha = 0.5) +
  # Both labels need closing `$` delimiters so latex2exp parses them as math.
  scale_x_log10(breaks = c(1e5, 1e7), labels = TeX(c('$10^5$', '$10^7$'))) +
  scale_color_viridis_d(option = 'cividis') +
  facet_wrap(~metric, scales = 'free_x') +
  theme_minimal() +
  scale_y_log10(labels = scales::percent) +
  theme(strip.text = element_text(size = 10, color = 'black')) +
  labs(color = 'Coupling parameter', y = 'Error rate on CIFAR-10', x = NULL)

# Compose Figure 3 (panel a on top, b and c below) with shared strip/axis
# styling and legends collected at the top, then save it. The composite is
# passed to ggsave() explicitly: `<plot> + ggsave()` stopped working in
# ggplot2 3.3.4.
fig_3 <-
  ((fig_3_a + theme(axis.text.x = element_text(size = 6))) / (fig_3_b + fig_3_c) &
     theme(strip.text = element_text(size = 10, color = 'black'),
           axis.title = element_text(size = 10),
           legend.position = 'top')) +
  plot_layout(guides = 'collect')

ggsave('figures/fig-3-raw.pdf', plot = fig_3, width = 8, height = 6)

# Compact stand-alone version of Figure 2c (linear error axis, strips below
# the panels), saved on its own. The plot is passed to ggsave() explicitly:
# `<plot> + ggsave()` stopped working in ggplot2 3.3.4.
fig_2_c_compact <-
  full_sir %>%
  inner_join(
    performance %>%
      filter(metric == 'acc', data == 'val') %>%
      select(-metric) %>%
      semi_join(best_epoch,
                by = c('init', 'coupling', 'it', 'Nr. of Channels', 'epoch')) %>%
      select(-epoch),
    by = c('coupling', 'init', 'it', 'Nr. of Channels')
  ) %>%
  ungroup() %>%
  # `Nr. of Channels` is a character label; compare against '16'.
  filter(init == 'nr', `Nr. of Channels` == '16',
         metric %in% c('Convergence Index', 'Divergence Index')) %>%
  ggplot(aes(auc, 1-value)) +
  geom_point() +
  theme_minimal() +
  labs(y = 'Error rate on CIFAR-10', x = NULL) +
  scale_y_continuous(labels = scales::percent, limits = c(0.045, 0.055),
                     breaks = c(0.045, 0.05, 0.055)) +
  scale_x_continuous(limits = c(0.25, 0.75), breaks = c(0.25, 0.5, 0.75)) +
  facet_wrap(~metric, ncol = 1, strip.position = 'bottom', scales = 'free_x') +
  theme(strip.placement = 'outside')

ggsave('figures/fig-2-c-compact.pdf', plot = fig_2_c_compact,
       width = 1.5, height = 2.5)
