
library(abind)
library(tidyr)
library(dplyr)
library(ggplot2)
library(reshape2)
library(scales)
library(gridExtra)

#### Section 5. Single_server ----
# load() is expected to put `myres` into the workspace.
# NOTE(review): `edpxy`, which plotfun() reads for the x-axis labels, is never
# assigned in this script; presumably it also comes from this file -- confirm.
load("single_server.Rdata")
# Only the fifth element of the loaded results is used in this section.
myres <- myres[[5]]

# cuti: first index into the 3-d per-method result arrays read below.
# corr: logical switch used as `3 + corr`, i.e. TRUE selects column 4 and
#       FALSE column 3. Spelled TRUE (not T) because T can be reassigned.
cuti <- 4
corr <- TRUE
# miall: indices of the classifiers that are reported; mis: their display names.
miall <- c(1:3, 8)
mis <- c('Logistic', 'SVM', 'DWD', 'CG')
# Draw, for several curves, the mean of the replicate values at every x-axis
# position together with a 2.5%-97.5% quantile error bar.
#
# Args:
#   data_lists:     list with one element per curve; each element is a list
#                   holding one numeric vector of replicate values per x-axis
#                   position.
#   mi:             accepted for call compatibility; not used in the body.
#   list_id_labels: legend labels, one per element of `data_lists`.
#   ry:             length-2 y-range handed to coord_cartesian().
#
# Returns: a ggplot object.
#
# Reads the global `edpxy` for the x-axis labels (indexed by position).
# Colours and line types are defined for up to 7 curves.
plotfun <- function(data_lists, mi, list_id_labels, ry = c(0.05, 0.95)) {
  n_curves <- length(data_lists)

  # Long format: one row per replicate value, tagged with curve and position.
  # The number of positions is taken from the data instead of being fixed at 8.
  df <- do.call(rbind, lapply(seq_len(n_curves), function(list_idx) {
    per_point <- data_lists[[list_idx]]
    data.frame(
      list_id = list_idx,
      point = rep(seq_along(per_point), lengths(per_point)),
      value = unlist(per_point)
    )
  }))

  # Mean and 95% quantile band per (curve, position); drop the grouping so the
  # result is a plain table and summarise() does not emit its regroup message.
  summary_df <- df %>%
    group_by(list_id, point) %>%
    summarise(
      mean = mean(value),
      lwr = quantile(value, 0.025),
      upr = quantile(value, 0.975),
      .groups = "drop"
    )
  v_x <- edpxy
  summary_df$point_label <- as.factor(v_x[summary_df$point])

  # Trim the line types to the number of curves so the legend override below
  # always has exactly one entry per legend key.
  line_types <- head(c(rep("solid", 4), rep("longdash", 3)), n_curves)
  mycolor <- c("#00BFC4", "#F8766D", "#00BA38", "#FFA500",
               "#007AFF", "#C77CFF", "#FF2D55")
  y_breaks <- c(0.2, 0.4, 0.6, 0.8)

  pics_mi <-
    ggplot(summary_df, aes(x = point_label, y = mean,
                           group = factor(list_id), color = factor(list_id))) +
    geom_line(position = position_dodge(0.2), aes(linetype = factor(list_id))) +
    geom_point(position = position_dodge(0.2)) +
    geom_errorbar(aes(ymin = lwr, ymax = upr), width = 0.2,
                  position = position_dodge(0.2), linetype = "dashed") +
    labs(title = '', x = expression(epsilon), y = "Misclassification rate") +
    scale_color_manual(values = mycolor, labels = list_id_labels, name = NULL) +
    scale_linetype_manual(values = line_types, labels = list_id_labels, name = NULL) +
    scale_y_continuous(
      breaks = y_breaks,
      labels = scales::percent_format(accuracy = 1)(y_breaks)
    ) +
    theme_minimal() +
    # Zoom only; coord_cartesian() does not discard values outside `ry`.
    coord_cartesian(ylim = ry) +
    theme(
      panel.border = element_rect(color = "black", fill = NA, size = 1),
      axis.text.x = element_text(color = "black"),
      axis.text.y = element_text(color = "black"),
      plot.title = element_text(size = 10),
      legend.position = c(1, 1),
      legend.justification = c(1, 1),
      legend.key.width = unit(1.5, 'cm')
    ) +
    # Show each curve's own line type in the colour legend keys.
    guides(color = guide_legend(override.aes = list(linetype = line_types)))
  pics_mi
}
# For every reported classifier `i`, collect the replicate values of each
# method at every element of `myres`, then draw one figure per classifier.
# Reads the globals `myres`, `cuti` and `corr`.
res_pert <- lapply(miall, function(i) {
  # Entry [i, col_j] of the 2-d table stored as element `k` of each replicate.
  pick_table <- function(k, col_j) {
    lapply(myres, function(myi) {
      unlist(lapply(myi, function(myj) myj[[k]][i, col_j]))
    })
  }
  # Entry [cuti, 3 + corr, i] of the 3-d array stored as element `k` of each
  # replicate, with NA entries removed.
  pick_array <- function(k) {
    lapply(myres, function(myi) {
      x <- unlist(lapply(myi, function(myj) myj[[k]][cuti, 3 + corr, i]))
      x[!is.na(x)]
    })
  }

  vini <- pick_table(1, 3 + corr)  # weak classifier
  vpert1 <- pick_array(2)          # MA
  vpert2 <- pick_array(3)          # MRMA
  vpert3 <- pick_array(4)          # MR
  vini1 <- pick_table(5, 1)        # voting
  vini2 <- pick_table(5, 2)        # averaging
  ciall <- pick_table(6, 2)        # all_data

  # Order must match `list_id_labels` below.
  list(vini, vpert3, vpert1, vpert2, vini1, vini2, ciall)
})
list_id_labels <- c('Weak', 'MR', 'MA', 'MRMA', 'Voting', 'Averaging', 'All data')
# One figure per entry of `res_pert` (i.e. per entry of `miall`) rather than a
# hard-coded 1:4. Alternative y-range kept from the original: c(0.05, 0.95).
pics <- lapply(seq_along(res_pert), function(mi) {
  plotfun(res_pert[[mi]], mi, list_id_labels, ry = c(0.1, 0.9))
})
for (i in seq_along(pics)) {
  plot(pics[[i]])
}


#### A.2 encoding ----
load("encoding.Rdata")
# For elements 4 to 6 of `myres`: average `acc0` and `acc1` over the replicates
# and stack the averaged `acc0` (padded with two leading NAs) on top of the
# averaged `acc1` as a row labelled 'ini'.
res_encoding <- lapply(myres[4:6], function(replicates) {
  n_rep <- length(replicates)
  acc0 <- c(NA, NA, Reduce('+', lapply(replicates, function(r) r$acc0)) / n_rep)
  acc1 <- Reduce('+', lapply(replicates, function(r) r$acc1)) / n_rep
  rbind('ini' = acc0, acc1)
})
# Stack the three blocks, drop stacked rows 5 and 9, keep the reported
# classifier columns, reorder the remaining rows, then transpose so that
# classifiers become rows and express everything in percent.
stacked <- do.call(rbind, res_encoding)
stacked <- stacked[-c(5, 9), miall]
stacked <- stacked[c(1, 2, 4, 3, 5, 7, 6, 8, 10, 9), ]
res_encoding <- t(stacked) * 100
rownames(res_encoding) <- mis
colnames(res_encoding)

# Emit the table as LaTeX.
library(xtable)
latex_code <- xtable(res_encoding)
print(latex_code, type = "latex")

#### A.2 perturbation ----
load("single_server.Rdata")
myres <- myres[[5]]

## boxplot
# Per reported classifier: a matrix whose first column comes from the first
# element of `myres` and whose remaining columns hold the values of every
# element of `myres` in reversed order (linear index `i` of element 6 of each
# replicate in both cases).
res_pert <- lapply(miall, function(i) {
  vini <- unlist(lapply(myres[[1]], function(myresi) myresi[[6]][i]))
  vpert <- lapply(myres, function(myi) {
    unlist(lapply(myi, function(myj) myj[[6]][i]))
  })
  cbind(vini, do.call(cbind, rev(vpert)))
})
# Note: par() changes the global graphics settings and is not restored here.
par(mfrow = c(1, 1), mar = c(2.2, 2.5, 1, 0.5), mgp = c(1.3, 0.1, 0), tck = -0.01)
# All classifiers side by side: 4 classifiers x 9 boxes = 36 positions.
boxplot(
  do.call(cbind, res_pert),
  outline = FALSE,
  ylab = 'Misclassification rate',
  xlab = expression(epsilon),
  xlim = c(1, 36),
  ylim = c(0.05, 0.95),
  xaxt = 'n',
  yaxt = 'n',
  main = ''
)
axis(2, at = seq(0.1, 0.9, 0.2),
     labels = sprintf('%0.0f%%', seq(0.1, 0.9, 0.2) * 100),
     cex.axis = 0.6, las = 1)
axis(1, at = 1:36, labels = rep(c('Tanh', edpxy[8:1]), 4), cex.axis = 0.5)
# Reference line at 50% and separators between the four classifier panels.
abline(h = 0.5, lty = 3)
abline(v = 9 * (1:3) + 0.5, lty = 3)
text((0:3) * 9 + 5, 0.9, mis, cex = 0.7, font = 2)

## density
# NOTE(review): `myall` is not assigned anywhere in this script; presumably it
# is supplied by a loaded .Rdata file -- confirm before running this section.
# Here each replicate is indexed directly with `[i]`, unlike the boxplot
# section, which reads element 6 of each replicate.
res_pert <- lapply(miall, function(i) {
  vini <- unlist(lapply(myall, function(myresi) myresi[i]))
  vpert <- lapply(myres, function(myi) unlist(lapply(myi, function(myj) myj[i])))
  cbind(vini, do.call(cbind, vpert[length(vpert):1]))
})
# One page of density panels per classifier.
for (i in seq_along(res_pert)) {
  acc <- 1 - res_pert[[i]]
  # Column index of every value, in the same column-major order as c(acc).
  # Derived from the matrix size instead of assuming 500 replicates, so a
  # different replicate count cannot be silently recycled onto wrong panels.
  data_df <- data.frame(
    matrix = rep(seq_len(ncol(acc)), each = nrow(acc)),
    value = c(acc)
  )
  data_df$matrix <- factor(data_df$matrix, levels = 1:9)
  levels(data_df$matrix) <- c('Tanh', edpxy[8:1])

  plot(ggplot(data_df, aes(x = value)) +
         geom_density() +
         facet_wrap(~matrix, scales = "free") +
         xlim(0, 1) + labs(x = "classification accuracy rate r") +
         # Five rounded breaks from 0 up to the ceiling of the y-range.
         scale_y_continuous(
           breaks = function(x) round(seq(0, ceiling(max(x)), length.out = 5))
         ))
}

#
#### A.3 sample size balance ----
# Reload the full result list; `myres` is no longer restricted to one element.
load("single_server.Rdata")
# First index into the 3-d result arrays read in this section.
cuti <- 8
# Legend labels "Case 1" ... "Case 5".
list_id_labels <- paste0('Case ', seq_len(5))
# Draw, for several curves, the mean of the replicate values at every x-axis
# position together with a 2.5%-97.5% quantile error bar, using the default
# discrete colour palette.
#
# Args:
#   data_lists:     list with one element per curve; each element is a list
#                   holding one numeric vector of replicate values per x-axis
#                   position.
#   mi:             accepted for call compatibility; not used in the body.
#   list_id_labels: legend labels, one per element of `data_lists`.
#   ry:             length-2 y-range handed to coord_cartesian().
#
# Returns: a ggplot object.
#
# Reads the global `edpxy` for the x-axis labels (indexed by position).
plotfun <- function(data_lists, mi, list_id_labels, ry = c(0.05, 0.95)) {
  # Long format: one row per replicate value, tagged with curve and position.
  # The number of positions is taken from the data instead of being fixed at 8.
  df <- do.call(rbind, lapply(seq_along(data_lists), function(list_idx) {
    per_point <- data_lists[[list_idx]]
    data.frame(
      list_id = list_idx,
      point = rep(seq_along(per_point), lengths(per_point)),
      value = unlist(per_point)
    )
  }))

  # Mean and 95% quantile band per (curve, position); drop the grouping so the
  # result is a plain table and summarise() does not emit its regroup message.
  summary_df <- df %>%
    group_by(list_id, point) %>%
    summarise(
      mean = mean(value),
      lwr = quantile(value, 0.025),
      upr = quantile(value, 0.975),
      .groups = "drop"
    )
  v_x <- edpxy
  summary_df$point_label <- as.factor(v_x[summary_df$point])

  # The original passed `linetype = factor(list_id)` to ggplot() outside aes(),
  # where it was never used; it is dropped so every curve stays a solid line,
  # exactly as rendered before. Map linetype inside aes() if per-curve line
  # types are wanted.
  pics_mi <-
    ggplot(summary_df, aes(x = point_label, y = mean,
                           group = factor(list_id), color = factor(list_id))) +
    geom_line(position = position_dodge(0.2)) +
    geom_point(position = position_dodge(0.2)) +
    geom_errorbar(aes(ymin = lwr, ymax = upr), width = 0.2,
                  position = position_dodge(0.2), linetype = 5) +
    labs(title = '',
         x = expression(epsilon), y = "Misclassification rate", color = "") +
    scale_color_discrete(labels = list_id_labels) +
    scale_y_continuous(labels = percent_format(scale = 100)) +
    theme_minimal() +
    # Zoom only; coord_cartesian() does not discard values outside `ry`.
    coord_cartesian(ylim = ry) +
    theme(
      panel.border = element_rect(color = "black", fill = NA, size = 1),
      axis.text.x = element_text(color = "black"),
      axis.text.y = element_text(color = "black"),
      plot.title = element_text(size = 10),
      legend.position = c(1, 1), legend.justification = c(1, 1)
    )
  pics_mi
}
# One figure per reported classifier, with one curve per retained element of
# `myres` (elements 6 to 8 are left out).
pics <- lapply(miall, function(mi) {
  # Entry [cuti, 4, mi] of element 2 of every replicate, NA entries removed.
  collect_values <- function(replicates) {
    vals <- unlist(lapply(replicates, function(one_rep) one_rep[[2]][cuti, 4, mi]))
    vals[!is.na(vals)]
  }
  kept_cases <- myres[-c(6:8)]
  nbalance <- lapply(kept_cases, function(one_case) {
    lapply(one_case, collect_values)
  })
  plotfun(nbalance, mi, list_id_labels = list_id_labels, ry = c(0.05, 0.85))
})
for (i in 1:4) {
  plot(pics[[i]])
}

#### A.4 sample size of weak classifier----
# Relies on two globals set earlier in the script: `myres` as loaded from
# "single_server.Rdata" (the full list) and `cuti` from the A.3 section.
my0 <- myres
corr <- TRUE
# For elements 5 to 8 of the loaded results: mean value of each of the four
# methods at every x-axis position, for classifier index 8 (the last entry of
# `miall`).
res_pert <- lapply(5:8, function(j) {
  myres <- my0[[j]]
  i <- 8
  # Entry [cuti, 3 + corr, i] of the 3-d array stored as element `k` of each
  # replicate, with NA entries removed.
  pick_array <- function(k) {
    lapply(myres, function(myi) {
      x <- unlist(lapply(myi, function(myj) myj[[k]][cuti, 3 + corr, i]))
      x[!is.na(x)]
    })
  }
  # weak classifier
  vini <- lapply(myres, function(myi) {
    unlist(lapply(myi, function(myj) myj[[1]][i, 3 + corr]))
  })
  vpert1 <- pick_array(2)  # MA
  vpert2 <- pick_array(3)  # MRMA
  vpert3 <- pick_array(4)  # MR

  # Order must match `group_labels` below.
  res <- list(vini, vpert3, vpert1, vpert2)
  res <- lapply(res, function(resi) {
    lapply(resi, function(rii) mean(rii, na.rm = TRUE))
  })
  # One row per x-axis position, one column per method.
  matrix(unlist(res), nrow = length(myres))
})
# positions x methods x cases
res_pert <- abind(res_pert, along = 3)
n_epsilons <- dim(res_pert)[1]
n_groups <- dim(res_pert)[2]
n_cases <- dim(res_pert)[3]
# expand.grid() varies its first factor fastest, matching the column-major
# order of c(res_pert).
df <- expand.grid(epsilon = 1:n_epsilons, group = 1:n_groups, case = 5:(4 + n_cases))
df$value <- c(res_pert)
group_labels <- c('Weak', 'MR', 'MA', 'MRMA')
case_labels <- paste("Case", 5:8)
df$group <- factor(df$group, labels = group_labels)
df$case <- factor(df$case, labels = case_labels)

# One colour per method; the earlier three-colour assignment was immediately
# overwritten and has been removed.
mc <- c("#00BFC4", "#F8766D", "#00BA38", "#FFA500")
mpch <- c(1, 2, 3, 5)
mlty <- c(1, 1, 1, 1)
y_breaks <- (1:5) / 10

ggplot_object <- ggplot(df, aes(x = factor(epsilon), y = value, group = interaction(group, case))) +
  geom_line(aes(linetype = group, color = group)) +
  geom_point(aes(shape = case, color = group)) +
  labs(title = '', x = expression(epsilon), y = "Misclassification rate") +
  scale_color_manual(values = mc, name = NULL) +
  scale_shape_manual(values = mpch, labels = case_labels, name = NULL) +
  scale_linetype_manual(values = mlty, labels = group_labels, name = NULL) +
  # Single y scale. The original added scale_y_continuous() twice, so the
  # first one (which carried limits = c(0.095, 0.505)) was replaced by this
  # one and its limits never took effect. The rendering is kept as it was.
  # NOTE(review): add coord_cartesian(ylim = c(0.095, 0.505)) if that window
  # was intended.
  scale_y_continuous(
    breaks = y_breaks,
    labels = scales::percent_format(accuracy = 1)(y_breaks)
  ) +
  theme_minimal() +
  theme(
    legend.position = "right",
    legend.background = element_blank(),
    legend.box.background = element_blank(),
    panel.border = element_rect(color = "black", fill = NA, size = 1),
    legend.key.width = unit(1, 'cm')
  ) +
  scale_x_discrete(
    breaks = 1:8,
    labels = as.numeric(edpxy),
    name = expression(epsilon)
  ) +
  labs(colour = "Group", shape = "Case", linetype = "Group") +
  guides(
    shape = guide_legend(order = 1),
    linetype = guide_legend(order = 2, override.aes = list(color = mc)),
    color = guide_legend(order = 2, override.aes = list(shape = NA))
  )
print(ggplot_object)


#### A.6 multi_server ----
load("multi_server.Rdata")
cuti <- 3
cutj <- 2
# Per reported classifier: the single-server values and the federated values
# at every element of `myres`.
fl <- lapply(miall, function(mi) {
  # Keep rows 1-10 and 16-25 of a matrix, flatten it and drop NA entries.
  flatten_kept_rows <- function(mat) {
    vals <- as.vector(mat[c(1:10, 16:25), ])
    vals[!is.na(vals)]
  }
  # Slice [cuti, 4, mi, ] of element 1 of every replicate, bound column-wise.
  fl_single <- lapply(myres, function(myi) {
    per_rep <- lapply(myi, function(myj) (myj[[1]])[cuti, 4, mi, ])
    flatten_kept_rows(abind(per_rep, along = 2))
  })
  # Entry [10 + cutj, cuti, 7, mi] of each item in element 2 of every
  # replicate, bound column-wise.
  fl_avg <- lapply(myres, function(myi) {
    per_rep <- lapply(myi, function(myj) {
      unlist(lapply(myj[[2]], function(myk) myk[10 + cutj, cuti, 7, mi]))
    })
    flatten_kept_rows(abind(per_rep, along = 2))
  })
  list(fl_single, fl_avg)
})
list_id_labels <- c('Single Server', 'Federated Learning')
# Uses whichever `plotfun` definition was evaluated last.
pics <- lapply(1:4, function(mi) plotfun(fl[[mi]], mi, list_id_labels))
for (i in 1:4) {
  plot(pics[[i]])
}








