# Load required libraries
library(ggplot2)
library(cowplot)
library(dplyr)
library(tidyr)
library(gridExtra)

# Define color palette
# Reproduce the familiar ggplot2 default discrete hue palette: `n` HCL colours
# (luminance 65, chroma 100) with hues evenly spaced by 360 / n, starting at 15.
#
# n: number of colours wanted (non-negative integer).
# Returns a character vector of `n` hex colour strings (empty when n == 0).
gg_color_hue <- function(n) {
  # Take n + 1 points over a full turn and drop the last one: hue 375 is the
  # same colour as hue 15.
  hues <- seq(15, 375, length.out = n + 1)
  # seq_len() (not 1:n) so that n == 0 yields no colours rather than one.
  hcl(h = hues, l = 65, c = 100)[seq_len(n)]
}

# Load the data
# Directory holding the simulation output. Defaults to the author's local
# path; set the FLEXIBILITY_DATA_DIR environment variable to run elsewhere.
path <- Sys.getenv(
  "FLEXIBILITY_DATA_DIR",
  unset = "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/sim/kim_et_all_exp/"
)
# Strip trailing separators so file.path() does not produce "//".
data_file <- file.path(sub("[/\\\\]+$", "", path), "flexibility.RDS")
if (!file.exists(data_file)) {
  stop("Cannot find input data file: ", data_file, call. = FALSE)
}
df <- readRDS(data_file)

# The plots below read these columns; fail early with a clear message.
missing_cols <- setdiff(c("b", "sb", "sb2", "method"), names(df))
if (length(missing_cols) > 0) {
  stop("Input data is missing column(s): ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}

# Copy `method` into `Operator` so that the colour legend is titled "Operator".
df$Operator <- df$method

# ----------- Shrinkage Operator Plots -----------

# MR.ASH/NASH shrinkage operator: S(b) (column `sb`) against b, one line per
# operator, zoomed to [0, 10] x [0, 10]. Legend suppressed here.
# `linewidth` replaces the line `size` aesthetic deprecated in ggplot2 3.4.0.
p1 <- ggplot(df) +
  geom_line(aes(x = b, y = sb, color = Operator), linewidth = 1) +
  theme_cowplot(font_size = 12) +
  labs(y = expression("Shrinkage operator " ~ S(b)),
       x = expression(b),
       title = "MR.ASH/NASH shrinkage operators") +
  theme(
    plot.title = element_text(hjust = 0.5),
    legend.position = "none") +
  coord_cartesian(xlim = c(0, 10), ylim = c(0, 10))

# Other shrinkage/thresholding operators: column `sb2` against b on the same
# zoom window; this panel keeps its legend.
p2 <- ggplot(df) +
  geom_line(aes(x = b, y = sb2, color = Operator), linewidth = 1) +
  theme_cowplot(font_size = 12) +
  labs(y = NULL,
       x = expression(b),
       title = "Other shrinkage/thresholding operators") +
  theme(
    plot.title = element_text(hjust = 0.5)) +
  coord_cartesian(xlim = c(0, 10), ylim = c(0, 10))

# X-axis label drawn as a standalone grob.
# NOTE(review): not referenced later in this script; confirm it is still needed.
xaxis <- ggdraw() + draw_label(expression(b), size = 18)

# Legend from p2, laid out horizontally.
# NOTE(review): not referenced later in this script; confirm it is still needed.
legend <- get_legend(p2 + theme(legend.position = "bottom",
                                legend.title = element_text(size = 14),
                                legend.text = element_text(size = 12)))

# ----------- Fused Lasso Penalty Surface -----------

# Both axes span [-4, 4] with 200 points; the same vectors are reused later
# for the prior-density grid.
beta1 <- seq(from = -4, to = 4, length.out = 200)
beta2 <- seq(from = -4, to = 4, length.out = 200)

# Fused-lasso penalty with unit weights: |b1| + |b2| + |b2 - b1|.
grid <- expand.grid(beta1 = beta1, beta2 = beta2)
grid$penalty <- with(grid, abs(beta1) + abs(beta2) + abs(beta2 - beta1))

# Heat map of the penalty, diverging around its mean over the grid.
P_fussed <- ggplot(grid, aes(x = beta1, y = beta2, z = penalty)) +
  geom_raster(aes(fill = penalty)) +
  scale_fill_gradient2(
    low = "red", mid = "white", high = "blue",
    midpoint = mean(grid$penalty)
  ) +
  labs(
    title = "Fused Lasso penalty surface",
    fill = "Penalty",
    x = expression(b[1]),
    y = expression(b[2])
  ) +
  coord_fixed() +
  theme_cowplot(font_size = 12)

# ----------- Prior Density from Graph Net ------------

# Laplace (double-exponential) density:
#   f(x) = exp(-|x - location| / scale) / (2 * scale)
# Vectorised over all arguments via ordinary R recycling.
dlaplace <- function(x, location = 0, scale = 1) {
  standardized_dist <- abs(x - location) / scale
  normalizer <- 1 / (2 * scale)
  normalizer * exp(-standardized_dist)
}

# Normalising constant of the product of two Laplace densities:
#   Z = integral over x of (lambda1 / 2) exp(-lambda1 |x - mu1|)
#                        * (lambda2 / 2) exp(-lambda2 |x - mu2|)
# obtained by splitting the real line at the two locations.
#
# lambda1, lambda2: positive rates (inverse scales), each paired with its mu.
# mu1, mu2: scalar locations (the ordering step uses a scalar if()).
# Returns the scalar Z.
Z_laplace_k2 <- function(lambda1, lambda2, mu1, mu2) {
  # Order the locations so mu1 <= mu2, keeping each rate with its location.
  if (mu1 > mu2) {
    tmp <- mu1
    mu1 <- mu2
    mu2 <- tmp
    tmp <- lambda1
    lambda1 <- lambda2
    lambda2 <- tmp
  }
  delta <- mu2 - mu1
  if (abs(lambda1 - lambda2) > 1e-10) {
    # Contributions of x < mu1, mu1 < x < mu2 and x > mu2, respectively.
    term1 <- exp(-lambda2 * delta) / (lambda1 + lambda2)
    term2 <- (exp(-lambda1 * delta) - exp(-lambda2 * delta)) / (lambda2 - lambda1)
    term3 <- exp(-lambda1 * delta) / (lambda1 + lambda2)
    Z <- (lambda1 * lambda2 / 4) * (term1 + term2 + term3)
  } else {
    # Equal-rate limit of the branch above: the two tails together give
    # exp(-lambda * delta) / lambda and the middle gives delta * exp(-lambda * delta),
    # so Z = (lambda / 4) * exp(-lambda * delta) * (1 + lambda * delta).
    # (The previous factor (2 + lambda * delta) doubled Z at delta = 0.)
    lambda <- lambda1
    Z <- (lambda / 4) * exp(-lambda * delta) * (1 + lambda * delta)
  }
  Z
}

# Product of the two full conditional densities of the GNN-based prior.
# Each coefficient's conditional is proportional to a Laplace kernel centred
# on the other coefficient (scale s1) times a Laplace kernel centred on 0
# (scale s2), divided by the matching Z_laplace_k2 normaliser. The result is
# p(beta1 | beta2) * p(beta2 | beta1).
# NOTE(review): a product of full conditionals is not in general a normalised
# joint density; confirm this is the intended quantity for the figure.
# Expects scalar beta1/beta2 (Z_laplace_k2 branches with a scalar if()).
joint_density <- function(beta1, beta2, s1, s2) {
  conditional <- function(x, other) {
    kernel <- dlaplace(x, other, s1) * dlaplace(x, 0, s2)
    kernel / Z_laplace_k2(1/s1, 1/s2, 0, other)
  }
  conditional(beta1, beta2) * conditional(beta2, beta1)
}

# Compute joint prior
# Evaluate the GNN-based prior on the same beta1 x beta2 axes used for the
# fused-lasso surface.
grid <- expand.grid(beta1 = beta1, beta2 = beta2)
s1 <- 0.45  # scale of the Laplace kernel linking the two coefficients
s2 <- 0.15  # scale of the Laplace kernel centred on zero

# joint_density() works on scalars (Z_laplace_k2 uses a scalar if()), so
# evaluate it point by point. vapply() avoids dplyr::rowwise() overhead on the
# 40,000 grid points and guarantees a numeric column.
grid$joint_density_value <- vapply(
  seq_len(nrow(grid)),
  function(i) joint_density(grid$beta1[i], grid$beta2[i], s1, s2),
  numeric(1)
)

# Heat map of -log10(density), diverging around its median over the grid.
# Axis labels are b[1]/b[2] to match the fused-lasso panel.
P_L <- ggplot(grid, aes(x = beta1, y = beta2, fill = -log10(joint_density_value))) +
  geom_raster() +
  scale_fill_gradient2(low = "red", mid = "white", high = "blue",
                       midpoint = median(-log10(grid$joint_density_value))) +
  labs(title = "-log10 of the prior density induced by GNN-based prior",
       x = expression(b[1]), y = expression(b[2]),
       fill = expression(-log[10] * "Density")) +
  coord_fixed() +
  theme_cowplot(font_size = 12)



P_L  # display the prior panel (auto-printed at top level)

# ----------- Final Assembly ------------

# Top row: the two shrinkage-operator panels at fixed positions on one canvas.
top_row <- ggdraw() +
  draw_plot(p1, x = 0, y = 0, width = 0.4, height = 1) +
  draw_plot(p2, x = 0.523, y = 0, width = 0.45, height = 1)

# Bottom row: prior density and fused-lasso penalty with aligned axes.
bottom_row <- plot_grid(
  P_L, P_fussed,
  nrow = 1, align = "hv", axis = "tblr"
)

# Stack the rows, giving the bottom row 60% of the height.
main_panel <- plot_grid(
  top_row, bottom_row,
  ncol = 1, rel_heights = c(0.4, 0.6)
)
final_figure <- plot_grid(main_panel, ncol = 1)

# Save figure
ggsave(filename = "figure1_for_paper.pdf", plot = final_figure,
       width = 12, height = 8, device = cairo_pdf)
ggsave(filename = "figure1b_for_paper.pdf", plot = bottom_row,
       width = 12, height = 5, device = cairo_pdf)


