library(cowplot)
library(metafolio)
library(ggplot2)
# Directory holding the simulation result files (*.RDS) read by this script.
# The file names are appended with paste0(), so the trailing slash is required.
path <- "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/sim/kim_et_all_exp/"
#' Evenly spaced colours on the HCL colour wheel
#'
#' Hues start at 15 degrees and are spread evenly around the circle, with
#' luminance 65 and chroma 100.
#'
#' @param n Number of colours to generate; a non-negative whole number.
#' @return A character vector of `n` hex colour strings (empty when `n` is 0).
gg_color_hue <- function(n) {
  # n + 1 points from 15 to 375: the last one (375 = 15 + 360) is the same hue
  # as the first, so it is dropped below.
  hues <- seq(15, 375, length.out = n + 1)
  # seq_len() instead of 1:n, so that n = 0 yields no colours rather than
  # indexing with c(1, 0) and returning one colour.
  hcl(h = hues, l = 65, c = 100)[seq_len(n)]
}


# Separately stored results; the scenario sections index this as
# newlist[[scenario]][[i]]$pred to override the Mr.ASH column.
mr_ash_file <- paste0(path, "fig_mr.ash.RDS")
newlist <- readRDS(mr_ash_file)
#```

### Low-dim

#```{r fig1_1, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 1 ("Low Dimension" panel).
# Averages the stored prediction errors per method for each sparsity level s
# and draws one line per method on log-log axes. This section also defines
# `lt`, which the later panels reuse for their linetype scale, and
# `fig_dummy1`, the plot the shared legend is extracted from at the end.
res_df <- readRDS(paste0(path, "scenario1.RDS"))
s_range <- c(1, 2, 5, 10, 20, 50, 100, 200)
out <- matrix(0, 8, 13)
lower <- matrix(0, 8, 13)
upper <- matrix(0, 8, 13)
for (i in 1:8) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- colMeans(pred_mat)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[1]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)[1:12]
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 1)
lt    <- c(1, 1, 2, 4, 1, 2, 4)
# Long format: c(out) stacks the columns, so each method contributes 8
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = 8),
                 lower = c(lower), upper = c(upper))
# The level order fixes the legend order and the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector
# (1.5 for the first 8 rows, 0.8 elsewhere). Which line ends up thicker depends
# on how ggplot2 orders the rows internally; presumably the intent is to
# emphasise Mr.ASH - confirm on the rendered figure before changing this.
df$size <- 0.8
df$size[1:8] <- 1.5
# `Method` is the aesthetic name shown as the legend title.
df$Method <- df$fit
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
fig_dummy1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = Method, linetype = Method), size = df$size) +
  geom_point(aes(x = s, y = pred, color = Method, shape = Method), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = s_range) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank()) +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, 1.35))
# The panel itself carries no legend; fig_dummy1 keeps it for later extraction.
fig_main <- fig_dummy1 + theme(legend.position = "none")
# Earlier title variants, kept for reference:
# suptitle  = ggdraw() + draw_label("Low-dimensional Setting", fontface  = 'bold', size = 14)
# subtitle  = ggdraw() + draw_label("Scenario: IndepGauss + PointNormal, n = 500, p = 200, s = 1-200, pve = 0.5", fontface  = 'bold', size = 14)
# fig1_1    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
# subtitle  = ggdraw() + draw_label("Low-dimensional Setting (Scenario 1)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("Low Dimension", fontface = 'bold', size = 22)
fig1_1 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

### High-dim

#```{r fig1_2, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 2 ("High Dimension" panel).
# Same construction as the other panels: mean prediction error per method for
# each sparsity level s, one line per method on log-log axes. Relies on `lt`
# (linetype values) having been defined earlier in the script.
res_df <- readRDS(paste0(path, "scenario2.RDS"))
s_range <- c(1, 5, 20, 100, 500, 2000, 10000)
out <- matrix(0, 7, 13)
lower <- matrix(0, 7, 13)
upper <- matrix(0, 7, 13)
for (i in 1:7) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- colMeans(pred_mat)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[2]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)[1:12]
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 1)
# Long format: c(out) stacks the columns, so each method contributes 7
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = 7),
                 lower = c(lower), upper = c(upper))
# The level order fixes the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector.
# Which line ends up thicker depends on how ggplot2 orders the rows
# internally; presumably the intent is to emphasise Mr.ASH - confirm on the
# rendered figure before changing this.
df$size <- 0.8
df$size[1:7] <- 1.5
# Drops the s = 2000 rows; the x-axis breaks below still use the full s_range,
# so a tick at 2000 is requested even though no point is drawn there.
df <- df[df$s %in% c(1, 5, 20, 100, 500, 10000), ]
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
p1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = fit, linetype = fit), size = df$size) +
  geom_point(aes(x = s, y = pred, color = fit, shape = fit), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = s_range) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        legend.position = "none") +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, 1.45))
fig_main <- p1
# Earlier title variants, kept for reference:
# suptitle  = ggdraw() + draw_label("High-dimensional Setting", fontface  = 'bold', size = 14)
# subtitle  = ggdraw() + draw_label("Scenario: IndepGauss + PointNormal, n = 500, p = 10000, s = 1-10000, pve = 0.5", fontface  = 'bold', size = 14)
# fig1_2    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
#subtitle  = ggdraw() + draw_label("High-dimensional Setting (Scenario 2)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("High Dimension", fontface = 'bold', size = 22)
fig1_2 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

### Point-const-signal

#```{r fig1_3, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 3 ("Point Constant Signal" panel).
# Same construction as the other panels: mean prediction error per method for
# each sparsity level s, one line per method on log-log axes. Relies on `lt`
# (linetype values) having been defined earlier in the script.
res_df <- readRDS(paste0(path, "scenario3.RDS"))
s_range <- c(1, 2, 5, 10, 20, 50, 100, 200)
out <- matrix(0, 8, 13)
lower <- matrix(0, 8, 13)
upper <- matrix(0, 8, 13)
for (i in 1:8) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- colMeans(pred_mat)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[3]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)[1:12]
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 1)
# Long format: c(out) stacks the columns, so each method contributes 8
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = 8),
                 lower = c(lower), upper = c(upper))
# The level order fixes the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector.
# Which line ends up thicker depends on how ggplot2 orders the rows
# internally; presumably the intent is to emphasise Mr.ASH - confirm on the
# rendered figure before changing this.
df$size <- 0.8
df$size[1:8] <- 1.5
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
p1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = fit, linetype = fit), size = df$size) +
  geom_point(aes(x = s, y = pred, color = fit, shape = fit), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = s_range) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        legend.position = "none") +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, sqrt(2)))
fig_main <- p1
# Earlier title variants, kept for reference:
# suptitle  = ggdraw() + draw_label("Sparse + Constant Signal", fontface  = 'bold', size = 14)
# subtitle  = ggdraw() + draw_label("Scenario: IndepGauss + PointConst, n = 500, p = 200, s = 1-200, pve = 0.5", fontface  = 'bold', size = 14)
# fig1_3    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
#subtitle  = ggdraw() + draw_label("PointConstant Signal (Scenario 3)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("Point Constant Signal", fontface = 'bold', size = 22)
fig1_3 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

### Strong-signal

#```{r fig1_4, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 4 ("Strong Signal" panel).
# Same construction as the other panels: mean prediction error per method for
# each sparsity level s, one line per method on log-log axes. Relies on `lt`
# (linetype values) having been defined earlier in the script.
res_df <- readRDS(paste0(path, "scenario4.RDS"))
s_range <- c(1, 2, 5, 10, 20, 50, 100, 200)
#s_range = c(1,5,20,100,500,2000)
a <- length(s_range)
out <- matrix(0, a, 13)
lower <- matrix(0, a, 13)
upper <- matrix(0, a, 13)
for (i in seq_len(a)) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- apply(pred_mat, 2, mean)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[4]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 10, 1)
# Long format: c(out) stacks the columns, so each method contributes `a`
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = a),
                 lower = c(lower), upper = c(upper))
# The level order fixes the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
# With the active s_range there is no s = 500, so this filter removes nothing;
# it only matters for the commented-out alternative s_range above.
df <- df[df$s != 500, ]
df$size <- 0.8
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                       "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector.
# Which line ends up thicker depends on how ggplot2 orders the rows
# internally; presumably the intent is to emphasise Mr.ASH - confirm on the
# rendered figure before changing this.
df$size[1:8] <- 1.5
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
p1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = fit, linetype = fit), size = df$size) +
  geom_point(aes(x = s, y = pred, color = fit, shape = fit), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = unique(df$s)) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        legend.position = "none") +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, sqrt(2)))
fig_main <- p1
# Earlier title variants, kept for reference:
# suptitle  = ggdraw() + draw_label("High Proportion of Variance (or R squared)", fontface  = 'bold', size = 14)
# subtitle  = ggdraw() + draw_label("Scenario: IndepGauss + PointNormal, n = 500, p = 200, s = 1-200, pve = 0.9", fontface  = 'bold', size = 14)
# fig1_4    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
#subtitle  = ggdraw() + draw_label("High PVE / Signal-to-Noise Ratio (Scenario 4)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("Strong Signal", fontface = 'bold', size = 22)
fig1_4 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

### High-corr-design

#```{r fig1_5, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 5 ("Equicorr" panel).
# Same construction as the other panels: mean prediction error per method for
# each sparsity level s, one line per method on log-log axes. Relies on `lt`
# (linetype values) having been defined earlier in the script.
res_df <- readRDS(paste0(path, "scenario5.RDS"))
s_range <- c(1, 5, 20, 100, 500, 2000)
a <- length(s_range)
out <- matrix(0, a, 13)
lower <- matrix(0, a, 13)
upper <- matrix(0, a, 13)
for (i in seq_len(a)) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- apply(pred_mat, 2, mean)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[5]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 10, 1)
# Long format: c(out) stacks the columns, so each method contributes `a`
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = a),
                 lower = c(lower), upper = c(upper))
# The level order fixes the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
# Drops the s = 500 rows, leaving 5 values of s per method.
df <- df[df$s != 500, ]
df$size <- 0.8
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector
# (1.5 for the first 5 rows, matching the 5 remaining values of s). Which line
# ends up thicker depends on how ggplot2 orders the rows internally;
# presumably the intent is to emphasise Mr.ASH - confirm on the rendered
# figure before changing this.
df$size[1:5] <- 1.5
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
p1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = fit, linetype = fit), size = df$size) +
  geom_point(aes(x = s, y = pred, color = fit, shape = fit), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = unique(df$s)) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        legend.position = "none") +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, 1.2))
fig_main <- p1
# Earlier title variants, kept for reference:
# suptitle  = ggdraw() + draw_label("Equicorrelated Design (rho = 0.95)", fontface  = 'bold', size = 14)
# subtitle  = ggdraw() + draw_label("Scenario: EquicorrGauss + PointNormal, n = 500, p = 2000, s = 1-2000, pve = 0.5", fontface  = 'bold', size = 14)
# fig1_5    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
#subtitle  = ggdraw() + draw_label("Equicorrelated Design (Scenario 5)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("Equicorr (ρ = 0.95)", fontface = 'bold', size = 22)
fig1_5 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

### Realgenotype-design

#```{r fig1_6, fig.height=11, fig.width=15, fig.align = "center"}
# Scenario 6 ("Real Genotype" panel).
# Same construction as the other panels: mean prediction error per method for
# each sparsity level s, one line per method on log-log axes. Relies on `lt`
# (linetype values) having been defined earlier in the script.
res_df <- readRDS(paste0(path, "scenario6.RDS"))
s_range <- c(1, 5, 20, 100, 500)
a <- length(s_range)
out <- matrix(0, a, 13)
lower <- matrix(0, a, 13)
upper <- matrix(0, a, 13)
for (i in seq_len(a)) {
  # Reshape once per sparsity level into a 20 x 13 matrix (one column per
  # method, named below) instead of rebuilding it for every summary.
  pred_mat <- matrix(res_df[[i]]$pred, 20, 13)
  out[i, ] <- apply(pred_mat, 2, mean)
  # Column 13 (Mr.ASH) is replaced by the mean of entries 21:40 of the
  # separately stored results. lower/upper keep the res_df values for that
  # column; they are carried in df but not drawn.
  out[i, 13] <- mean(newlist[[6]][[i]]$pred[21:40])
  lower[i, ] <- apply(pred_mat, 2, quantile, probs = 0.1)
  upper[i, ] <- apply(pred_mat, 2, quantile, probs = 0.9)
}
colnames(out) <- c("Mr.ASH.vanilla", "varbvs", "BayesB", "B-Lasso", "SuSiE",
                   "E-NET", "Lasso", "Ridge", "SCAD", "MCP", "L0Learn",
                   "Mr.ASH.order", "Mr.ASH")
# Column selection; 1:13 keeps every method (rows are filtered by name below).
ind <- 1:13
out <- out[, ind]
lower <- lower[, ind]
upper <- upper[, ind]
col   <- gg_color_hue(13)
shape <- c(19, 17, 24, 25, 9, 3, 11, 4, 5, 7, 8, 10, 1)
# A filter on `df` used to sit here, before df was built for this scenario: it
# acted on the previous section's data frame and was overwritten immediately,
# so it has been removed. No s value is dropped in this panel.
# Long format: c(out) stacks the columns, so each method contributes `a`
# consecutive rows (one per value of s).
df <- data.frame(s = rep(s_range, length(ind)), pred = c(out),
                 fit = rep(colnames(out), each = a),
                 lower = c(lower), upper = c(upper))
# The level order fixes the colour/shape/linetype mapping.
df$fit <- factor(df$fit, levels = c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                                    "SCAD", "MCP", "L0Learn",
                                    "varbvs", "BayesB", "B-Lasso", "SuSiE",
                                    "Mr.ASH.order", "Mr.ASH.vanilla"))
df$size <- 0.8
df <- df[df$fit %in% c("Mr.ASH", "E-NET", "Lasso", "Ridge",
                       "SCAD", "MCP", "L0Learn"), ]
# NOTE(review): the line width is passed outside aes() as a per-row vector.
# Which line ends up thicker depends on how ggplot2 orders the rows
# internally; presumably the intent is to emphasise Mr.ASH - confirm on the
# rendered figure before changing this.
df$size[1:5] <- 1.5
# Axis titles are blanked in theme(); the labels given to labs() are therefore
# not displayed in this panel.
p1 <- ggplot(df) +
  geom_line(aes(x = s, y = pred, color = fit, linetype = fit), size = df$size) +
  geom_point(aes(x = s, y = pred, color = fit, shape = fit), size = 3) +
  theme_cowplot(font_size = 22) +
  scale_x_continuous(trans = "log10", breaks = unique(df$s)) +
  labs(y = "prediction error (rmse / sigma)", x = "number of nonzero coefficients (s)") +
  theme(axis.line = element_blank(),
        plot.title = element_text(hjust = 0.5),
        axis.text.x  = element_text(angle = 0, hjust = 1),
        axis.title.y = element_blank(), axis.title.x = element_blank(),
        legend.position = "none") +
  scale_color_manual(values = col) +
  scale_shape_manual(values = shape) +
  scale_linetype_manual(values = lt) +
  scale_y_continuous(trans = "log10", breaks = c(1, 1.1, 1.2, 1.3, 1.4)) +
  coord_cartesian(ylim = c(1, 1.4))
fig_main <- p1
# Earlier title variants, kept for reference:
#suptitle  = ggdraw() + draw_label("Equicorrelated Design (rho = 0.95)", fontface  = 'bold', size = 14)
#subtitle  = ggdraw() + draw_label("Scenario: EquicorrGauss + PointNormal, n = 500, p = 200, s = 1-200, pve = 0.5", fontface  = 'bold', size = 14)
#fig1_6    = plot_grid(suptitle, subtitle,fig_main, ncol = 1, rel_heights = c(0.04,0.04,0.95))
#subtitle  = ggdraw() + draw_label("Real Genotype Design (Scenario 6)", fontface  = 'bold', size = 22)
subtitle <- ggdraw() + draw_label("Real Genotype", fontface = 'bold', size = 22)
fig1_6 <- plot_grid(subtitle, fig_main, ncol = 1, rel_heights = c(0.08, 0.92))
#```

## Draw Figure

#```{r figure2}
# Assemble the final figure: the six panels in a two-column grid, shared axis
# labels on the left and at the bottom, and the shared legend on the right.
# The result is written as a PDF into the current working directory.
panel_grid <- plot_grid(fig1_1, fig1_2, fig1_3, fig1_4, fig1_5, fig1_6,
                        ncol = 2, rel_widths = c(0.5, 0.5))
#title    = ggdraw() + draw_label("Adaptation to Sparsity (Penalized Linear Regression)", fontface = 'bold', size = 22)

# The legend is taken from the scenario-1 plot that still carries one.
legend_source <- fig_dummy1 +
  theme_cowplot(font_size = 22) +
  theme(legend.box.margin = margin(0, 0, 0, 12),
        legend.key.size = unit(1.2, "cm"))
legend <- get_legend(legend_source)

# Shared axis labels, drawn once for the whole grid.
yaxis <- ggdraw() + draw_label("prediction error (rmse / sigma)", size = 25, angle = 90)
xaxis <- ggdraw() + draw_label("number of nonzero coefficients (s)", size = 25, angle = 0)

# Build from the inside out: x label under the panels, y label to their left.
panels_with_x <- plot_grid(panel_grid, xaxis,
                           rel_heights = c(0.97, 0.03), ncol = 1)
panels_with_axes <- plot_grid(yaxis, panels_with_x,
                              rel_widths = c(0.03, 0.97), nrow = 1)
# Single-plot wrapper with two relative heights, kept exactly as in the
# original layout (the title line above is commented out).
figure_body <- plot_grid(panels_with_axes,
                         ncol = 1, rel_heights = c(0.05, 0.95))
figure2 <- plot_grid(figure_body, legend, nrow = 1, rel_widths = c(0.9, 0.1))
ggsave("figure2_for_paper.pdf", figure2, width = 18, height = 22, device=cairo_pdf)
#```
