
################################################
### Simulation for Relative Error Estimation ###
################################################

# Relative Frobenius-norm error of an estimate:
#   ||est - truth||_F / ||truth||_F
# Both arguments are matrices of the same dimensions.
relative_error <- function(est, truth) {
  residual_norm <- norm(est - truth, type = "F")
  residual_norm / norm(truth, type = "F")
}

# Average relative error of four Ising-graph estimators for one perturbation SD.
#
# For each informative-source count S_inf in `S_values`, repeats
# `num_simulations` times:
#   * draw target couplings theta_true <- generate_theta_true(p, graph_type)
#     and n_0 target samples with IsingSampler();
#   * build S_inf informative auxiliary samples (n_A rows each) whose couplings
#     are theta_true plus symmetrised N(0, sd_perturb^2) noise, zero diagonal;
#   * build S_total - S_inf irrelevant auxiliary samples whose couplings come
#     from generate_theta_true_2(p);
#   * score each estimator's beta_hat against theta_true with relative_error().
#
# Methods:
#   "Oracle-Trans-Ising" - oracle_trans_loglasso() on the informative sources;
#   "Trans-Ising"        - trans_loglasso() on all sources;
#   "Naive-LogLasso"     - naive_loglasso() on the target data only;
#   "Pooled-Trans-Ising" - oracle_trans_loglasso() on all sources.
#
# Args:
#   sd_perturb: noise SD for the informative sources.
#   p, n_0, n_A: node count, target sample size, per-source auxiliary size.
#   S_total: number of auxiliary sources.
#   S_values: informative counts to sweep; each must lie in [0, S_total].
#   lambda_j, lambda_delta: forwarded unchanged to the estimators.
#   num_simulations: replicates per S value (0 gives NaN averages).
#   graph_type: forwarded to generate_theta_true().
#
# Returns: data.frame with one row per (method, S value) and columns S,
#   RelativeError (mean over replicates), Method and Perturbation (a label
#   that the plotting functions filter on).
run_simulation_for_sd <- function(sd_perturb,
                                  p = 200,
                                  n_0 = 160, n_A = 300,
                                  S_total = 12, S_values = c(0, 3, 6, 9, 12),
                                  lambda_j = NULL, lambda_delta = NULL,
                                  num_simulations = 100,
                                  graph_type) {
  stopifnot(all(S_values >= 0), all(S_values <= S_total))

  # One vector drives both the result columns and the Method labels, so the
  # two cannot drift apart.
  methods <- c("Oracle-Trans-Ising", "Trans-Ising",
               "Naive-LogLasso", "Pooled-Trans-Ising")
  results <- matrix(NA_real_, nrow = length(S_values), ncol = length(methods),
                    dimnames = list(NULL, methods))

  for (i in seq_along(S_values)) {
    S_inf <- S_values[i]
    S_irrel <- S_total - S_inf

    # Row = replicate, column = method.
    errs <- matrix(NA_real_, nrow = num_simulations, ncol = length(methods),
                   dimnames = list(NULL, methods))

    # seq_len(): with num_simulations = 0, 1:0 would run replicates 1 and 0.
    for (sim in seq_len(num_simulations)) {
      cat(sprintf("Perturbation = %.2f, S_inf = %d: Running Simulation #%d / %d\n",
                  sd_perturb, S_inf, sim, num_simulations))

      theta_true <- generate_theta_true(p, graph_type = graph_type)
      x_0 <- IsingSampler(n_0, theta_true, thresholds = -rowSums(theta_true) / 2)

      x_A_inf <- lapply(seq_len(S_inf), function(s) {
        perturb <- matrix(rnorm(p^2, mean = 0, sd = sd_perturb), p, p)
        # Symmetrise and zero the diagonal so the perturbed couplings are a
        # symmetric matrix without self-interactions.
        theta_perturb <- (theta_true + perturb + t(theta_true + perturb)) / 2
        diag(theta_perturb) <- 0
        IsingSampler(n_A, theta_perturb, thresholds = -rowSums(theta_perturb) / 2)
      })

      x_A_irrel <- lapply(seq_len(S_irrel), function(s) {
        theta_rand <- generate_theta_true_2(p)
        IsingSampler(n_A, theta_rand, thresholds = -rowSums(theta_rand) / 2)
      })

      x_A_list <- c(x_A_inf, x_A_irrel)

      # Estimators are called in the same order as before (oracle, trans,
      # naive, pooled) in case they consume random numbers.
      # The oracle's sources are exactly x_A_inf (list() when S_inf == 0).
      est_oracle <- oracle_trans_loglasso(x_A_inf, x_0, lambda_j, lambda_delta)$beta_hat
      est_trans <- trans_loglasso(x_A_list, x_0, lambda_j, lambda_delta)$beta_hat
      est_naive <- naive_loglasso(x_0, lambda_delta)
      est_pooled <- oracle_trans_loglasso(x_A_list, x_0, lambda_j, lambda_delta)$beta_hat

      errs[sim, ] <- c(relative_error(est_oracle, theta_true),
                       relative_error(est_trans, theta_true),
                       relative_error(est_naive, theta_true),
                       relative_error(est_pooled, theta_true))
    }

    results[i, ] <- vapply(seq_along(methods),
                           function(k) mean(errs[, k]), numeric(1))
  }

  pert_level <- switch(as.character(sd_perturb),
                       "0.01" = "perturbation level = 1",
                       "0.1"  = "perturbation level = 10",
                       "0.2"  = "perturbation level = 20",
                       paste0("perturbation level = ", sd_perturb))

  # as.vector() on the matrix is column-major: all S values for method 1,
  # then method 2, ... which matches rep(methods, each = ...).
  data.frame(
    S = rep(S_values, times = length(methods)),
    RelativeError = as.vector(results),
    Method = rep(methods, each = length(S_values)),
    Perturbation = pert_level
  )
}

# Standard deviations of the Gaussian noise added to the target couplings of
# the informative auxiliary sources. The plot labels render 0.01 / 0.1 / 0.2
# as perturbation levels 1 / 10 / 20.
sd_levels <- c(0.01, 0.10, 0.20)

# Build one panel for create_perturbation_plot(): average relative error
# against |S| for the rows of `df` whose Perturbation equals `level`. The
# panel is titled with that label and has no legend of its own.
.perturbation_level_panel <- function(df, level) {
  method_linetypes <- c("Oracle-Trans-Ising" = "dashed", "Trans-Ising" = "dashed",
                        "Naive-LogLasso" = "solid", "Pooled-Trans-Ising" = "solid")

  # %in% keeps exactly the rows subset(df, Perturbation == level) would keep,
  # without non-standard evaluation.
  ggplot(df[df$Perturbation %in% level, , drop = FALSE],
         aes(x = S, y = RelativeError, color = Method, shape = Method, linetype = Method)) +
    geom_line(size = 1.2) +
    geom_point(size = 5.5) +
    scale_x_continuous(breaks = c(0, 3, 6, 9, 12)) +
    scale_linetype_manual(values = method_linetypes) +
    labs(x = "|S|", y = "Average Relative Error", title = level) +
    theme_minimal(base_family = "Times New Roman") +
    theme(text = element_text(size = 25),
          plot.title = element_text(hjust = 0.5, size = 25),
          legend.position = "none")
}

# Three-panel patchwork of the relative-error results from
# run_simulation_for_sd(). Perturbation levels 1, 10 and 20 are shown side by
# side, and the panel legends are collected into one legend at the bottom.
#
# Args:
#   df: data frame with columns S, RelativeError, Method and Perturbation.
#   g_type: graph type; not used by the plot (kept for call compatibility).
# Returns: a patchwork object.
create_perturbation_plot <- function(df, g_type) {
  level_labels <- paste0("perturbation level = ", c(1, 10, 20))
  panels <- lapply(level_labels, function(level) .perturbation_level_panel(df, level))

  (Reduce(`+`, panels) + plot_layout(guides = "collect")) &
    theme(legend.position = "bottom")
}


# Relative-error simulation driver. For each graph type, run
# run_simulation_for_sd() at every SD in `sd_levels`, stack the results and
# save them as simulation_results_<graph>.rds.
graph_types_to_test <- c("sparse_block_star")

for (g_type in graph_types_to_test) {

  cat(sprintf("\n>>> graph type: %s simulation start...\n", g_type))

  df_result <- do.call(rbind, lapply(sd_levels, run_simulation_for_sd, graph_type = g_type))

  output_filename <- paste0("simulation_results_", g_type, ".rds")

  saveRDS(df_result, file = output_filename)

  cat(sprintf(">>> graph type: %s simulation complete. results saved at '%s' \n", g_type, output_filename))
}

cat("\n--- All simulations and file save complete ---\n")




### making plots of relative error ###

# Build one panel for create_perturbation_plot_component(): average relative
# error against |A| for the rows of `df` whose Perturbation equals `level`,
# with the given title (NULL for none) and subtitle. The caller adds the theme.
.perturbation_strip_panel <- function(df, level, title, subtitle) {
  method_linetypes <- c("Oracle-Trans-Ising" = "dashed", "Trans-Ising" = "dashed",
                        "Naive-LogLasso" = "solid", "Pooled-Trans-Ising" = "solid")

  ggplot(df[df$Perturbation %in% level, , drop = FALSE],
         aes(x = S, y = RelativeError, color = Method, shape = Method, linetype = Method)) +
    geom_line(size = 1.2) + geom_point(size = 4) +
    scale_x_continuous(breaks = c(0, 3, 6, 9, 12)) +
    scale_linetype_manual(values = method_linetypes) +
    labs(x = "|A|", y = "Avg Relative Error",
         title = title,
         subtitle = subtitle)
}

# One-row strip of three relative-error panels (perturbation 1, 10, 20) for a
# single graph type. It is meant to be stacked with the strips of other graph
# types, so the graph name is shown once, on the middle panel.
#
# Args:
#   df: data frame with columns S, RelativeError, Method and Perturbation.
#   g_type: graph type key; mapped to a readable title, or used verbatim
#     when unknown.
# Returns: a patchwork object with the three panels in one row.
create_perturbation_plot_component <- function(df, g_type) {

  main_plot_title <- switch(g_type,
                            "sparse_block_star" = "Sparse Block Star Graph",
                            "block_star"        = "Block-wise Star Graph",
                            "chain"             = "Chain Graph",
                            "random"            = "Random Sparse Graph",
                            g_type)

  my_theme <- theme_minimal(base_family = "Times New Roman") +
    theme(
      text = element_text(size = 22),
      axis.title = element_text(size = 20),
      axis.text = element_text(size = 16),

      legend.position = "bottom",
      legend.title = element_blank(),
      legend.text = element_text(size = 24),
      legend.key.size = unit(1.5, "cm"),
      legend.box.margin = margin(t = 20, b = 10),
      plot.title = element_text(hjust = 0.5, face = "bold", size = 27),
      plot.subtitle = element_text(hjust = 0.5, face = "plain", size = 18, margin = margin(b = 10))
    )

  level_values <- c(1, 10, 20)

  panels <- lapply(seq_along(level_values), function(k) {
    panel <- .perturbation_strip_panel(
      df,
      level = paste0("perturbation level = ", level_values[k]),
      # The graph name goes on the middle panel only, so it is centred
      # above the three-panel strip.
      title = if (k == 2L) main_plot_title else NULL,
      subtitle = paste0("perturbation = ", level_values[k])
    ) + my_theme
    # Only the left-most panel keeps its y-axis title.
    if (k > 1L) {
      panel <- panel + theme(axis.title.y = element_blank())
    }
    panel
  })

  Reduce(`+`, panels) + plot_layout(nrow = 1)
}


# Combine the relative-error strip of every graph type whose results file
# exists into a two-column grid with one shared legend at the bottom, and
# save it as a PNG.
g_types <- c("random", "block_star", "chain", "sparse_block_star")
plots_list <- list()

for (g_type in g_types) {
  results_file <- paste0("simulation_results_", g_type, ".rds")
  if (file.exists(results_file)) {
    df_temp <- readRDS(results_file)
    plots_list[[length(plots_list) + 1]] <- create_perturbation_plot_component(df_temp, g_type)
  }
}

# Guard so an empty list never reaches wrap_plots()/ggsave() when no
# results file exists yet (same pattern as the PR/ROC plot sections).
if (length(plots_list) > 0) {

  combined_final_plot <- wrap_plots(plots_list, ncol = 2) +
    plot_layout(guides = "collect") +
    plot_annotation(theme = theme(legend.position = "bottom"))

  output_filename <- "2plot_combined_2x2_fixed_subtitle.png"

  ggsave(
    filename = output_filename,
    plot = combined_final_plot,
    width = 40,
    height = 20,
    units = "cm",
    dpi = 500
  )

} else {
  cat("No plots created.\n")
}




###########################
### PR curve simulation ###
###########################

# Precision-recall simulation. For each of `num_sim` replicates:
#   * generate target data plus 3 informative and 9 irrelevant auxiliary
#     sources;
#   * fit the four estimators;
#   * score every entry of the coupling matrix by |estimate| as a predictor
#     of the true support (theta_true != 0).
#
# NOTE: despite its name, this definition returns precision/recall points.
# The ROC section further down redefines run_roc_simulation(); whichever
# definition was evaluated last is the one in effect.
#
# Args:
#   p, n_0, n_A: node count, target sample size, per-source auxiliary size.
#   sd_perturb: noise SD for the informative sources.
#   lambda_j, lambda_delta: forwarded unchanged to the estimators.
#   num_sim: number of replicates.
#   graph_type: forwarded to generate_theta_true().
# Returns: data frame with columns Recall, Precision, Method and Sim. Rows
#   with non-finite precision are dropped, and rows are sorted by recall
#   within each method/replicate.
run_roc_simulation <- function(p = 200, n_0 = 160, n_A = 300,
                               sd_perturb = 0.2,
                               lambda_j = NULL, lambda_delta = NULL,
                               num_sim = 1,
                               graph_type) {
  roc_df_list <- vector("list", num_sim)

  # seq_len(): 1:num_sim would iterate over 1 and 0 when num_sim is 0.
  for (sim in seq_len(num_sim)) {
    message(sprintf("Graph: %s, Simulation: %d/%d", graph_type, sim, num_sim))

    theta_true <- generate_theta_true(p, graph_type = graph_type)
    x_0 <- IsingSampler(n_0, theta_true, thresholds = -rowSums(theta_true)/2)

    x_A_inf <- lapply(1:3, function(s) {
      perturb <- matrix(rnorm(p^2, 0, sd_perturb), p, p)
      theta_inf <- (theta_true + perturb + t(theta_true + perturb)) / 2
      diag(theta_inf) <- 0
      IsingSampler(n_A, theta_inf, thresholds = -rowSums(theta_inf)/2)
    })

    x_A_irrel <- lapply(1:9, function(s) {
      theta_bad <- generate_theta_true_2(p)
      IsingSampler(n_A, theta_bad, thresholds = -rowSums(theta_bad)/2)
    })

    x_A_all <- c(x_A_inf, x_A_irrel)

    theta_estimates <- list(
      "Oracle-Trans-Ising" = oracle_trans_loglasso(x_A_inf, x_0, lambda_j, lambda_delta)$beta_hat,
      "Trans-Ising"        = trans_loglasso(x_A_all, x_0, lambda_j, lambda_delta)$beta_hat,
      "Naive-LogLasso"     = naive_loglasso(x_0, lambda_delta),
      "Pooled-Trans-Ising" = oracle_trans_loglasso(x_A_all, x_0, lambda_j, lambda_delta)$beta_hat
    )

    is_edge <- as.vector(theta_true != 0)

    sim_results <- lapply(names(theta_estimates), function(method_name) {
      est <- theta_estimates[[method_name]]
      # Fix controls/cases and the direction explicitly (larger |estimate| =
      # edge). pROC's default direction = "auto" re-chooses the orientation
      # from the data on every call, which can flip or inflate a weak
      # estimator's curve.
      roc_obj <- roc(is_edge, abs(as.vector(est)),
                     levels = c(FALSE, TRUE), direction = "<", quiet = TRUE)

      coords(roc_obj, "all", ret = c("recall", "precision")) %>%
        as_tibble() %>%
        rename(Recall = recall, Precision = precision) %>%
        add_column(Method = method_name, Sim = sim) %>%
        filter(is.finite(Precision)) %>%
        arrange(Recall)
    })

    roc_df_list[[sim]] <- do.call(rbind, sim_results)
  }

  do.call(rbind, roc_df_list)
}

# Bin-average raw ROC points across simulations.
#
# FPR is cut into 100 equal-width bins on [0, 1]. Within each (Method, bin),
# FPR, TPR and AUC are averaged. One (FPR = 0, TPR = 0) start row per method
# is prepended, carrying the mean of that method's binned AUC values, and the
# result is ordered by Method, then FPR.
#
# NOTE: later definitions of process_roc_data() in this file replace this one
# when they are evaluated.
process_roc_data <- function(raw_roc_df) {
  # include.lowest = TRUE puts FPR == 0 into the first bin.
  fpr_breaks <- seq(0, 1, length.out = 101)

  binned <- raw_roc_df %>%
    mutate(FPR_bin = cut(FPR, breaks = fpr_breaks, include.lowest = TRUE)) %>%
    group_by(Method, FPR_bin)

  avg_roc_df <- summarise(binned,
                          FPR = mean(FPR, na.rm = TRUE),
                          TPR = mean(TPR, na.rm = TRUE),
                          AUC = mean(AUC, na.rm = TRUE),
                          .groups = "drop")

  origin_rows <- avg_roc_df %>%
    group_by(Method) %>%
    summarise(FPR = 0,
              TPR = 0,
              AUC = mean(AUC, na.rm = TRUE))

  arrange(bind_rows(origin_rows, avg_roc_df), Method, FPR)
}


# Plot bin-averaged ROC curves (FPR vs TPR, one colour per method). Each
# method's mean AUC is printed as text at x = 0.65, with the labels stacked
# from y = 0.4 down to y = 0.1.
#
# Args:
#   processed_df: data frame with FPR, TPR, AUC and Method columns.
#   plot_title: title shown above the plot.
# Returns: a ggplot object.
#
# NOTE: later definitions of create_roc_plot() in this file replace this one
# when they are evaluated.
create_roc_plot <- function(processed_df, plot_title) {
  auc_labels <- processed_df %>%
    group_by(Method) %>%
    summarise(AUC = mean(AUC)) %>%
    mutate(x = 0.65, y = seq(0.4, 0.1, length.out = n()))

  roc_lines <- ggplot(processed_df, aes(x = FPR, y = TPR, color = Method)) +
    geom_line(size = 1.5) +
    labs(title = plot_title,
         x = "1 - Specificity",
         y = "Sensitivity",
         color = "Method")

  styled <- roc_lines +
    theme_minimal() +
    theme(
      plot.title = element_text(hjust = 0.5, size = 20),
      legend.title = element_text(size = 16),
      legend.text = element_text(size = 14),
      axis.title = element_text(size = 16),
      axis.text = element_text(size = 14)
    )

  styled +
    geom_text(
      data = auc_labels,
      aes(x = x, y = y, label = paste0("AUC = ", round(AUC, 3)), color = Method),
      size = 6, show.legend = FALSE
    )
}

# PR simulation driver. For each graph type, run `total_simulations`
# replicates at sd_perturb = perturb_level with the PR-returning
# run_roc_simulation() defined above. The raw precision/recall points are
# saved as sim_results_pr_g<sd>_<graph>.rds.
graph_configs <- list(
  block_star        = "Block-wise Star graph",
  sparse_block_star = "Sparse Block Star Graph",
  chain             = "Chain Graph",
  random            = "Random Sparse Graph"
)

perturb_level <- 0.01
total_simulations <- 100

cat(sprintf("--- Simulation Start (sd_perturb = %.2f) ---\n", perturb_level))

for (g_type in names(graph_configs)) {
  cat(sprintf("\n>>> Graph Type: %s Simulation Start...\n", g_type))

  output_filename <- sprintf("sim_results_pr_g%.2f_%s.rds", perturb_level, g_type)
  raw_df <- run_roc_simulation(
    graph_type = g_type,
    sd_perturb = perturb_level,
    num_sim = total_simulations
  )
  saveRDS(raw_df, file = output_filename)

  cat(sprintf(">>> Graph Type: %s Simulation Complete. Saved to '%s' \n",
              g_type, output_filename))
}

cat(sprintf("\n--- All Complete (sd_perturb = %.2f) ---\n", perturb_level))





### Making PR plots ###

library(glmnet)
library(dplyr)
library(ggplot2)
library(patchwork)
library(tibble)

# Interpolated precision-recall curve of one simulation on a fixed recall grid.
#
# Steps:
#   1. Keep the best precision observed at each recall value.
#   2. Replace precision by its interpolated envelope max_{r' >= r} P(r'), so
#      the curve is non-increasing in recall.
#   3. Linearly interpolate onto `grid`.
# Grid points outside the observed recall range get NA (approx(rule = 1)).
#
# Args:
#   df_one_sim: data frame with numeric Recall and Precision columns.
#   grid: recall values at which to evaluate the curve.
# Returns: tibble(Recall = grid, Precision = interpolated precision).
#   Precision is all NA when fewer than two distinct recall values exist.
.monotone_and_grid <- function(df_one_sim, grid = seq(0, 1, length.out = 501)) {
  df2 <- df_one_sim %>%
    # desc(Precision) makes distinct() keep the highest precision at a tied
    # recall; keeping an arbitrary tied row could lower the envelope.
    arrange(Recall, desc(Precision)) %>%
    distinct(Recall, .keep_all = TRUE) %>%
    mutate(Precision = rev(cummax(rev(Precision))))

  # approx() needs at least two points to interpolate.
  if (nrow(df2) < 2) {
    return(tibble(Recall = grid, Precision = rep(NA_real_, length(grid))))
  }

  pr <- approx(x = df2$Recall, y = df2$Precision, xout = grid,
               ties = "ordered", rule = 1)$y
  tibble(Recall = grid, Precision = pr)
}

# Average PR curves across simulations on a common recall grid.
#
# Each (Method, Sim) group is turned into an interpolated curve on `grid` by
# .monotone_and_grid(). Precision is then averaged over simulations at every
# grid recall, ignoring NA values.
#
# Args:
#   raw_pr_df: rows with Recall, Precision, Method and optionally Sim. Without
#     a Sim column, all rows are treated as one simulation.
#   grid: recall values at which curves are evaluated.
# Returns: list(pr_curve = tibble with Method, Recall and mean Precision).
process_pr_data <- function(raw_pr_df, grid = seq(0, 1, length.out = 501)) {
  if (!"Sim" %in% names(raw_pr_df)) {
    raw_pr_df$Sim <- 1L
  }

  per_sim_grid <- ungroup(group_modify(
    group_by(raw_pr_df, Method, Sim),
    ~ .monotone_and_grid(.x[, c("Recall", "Precision")], grid = grid)
  ))

  averaged <- summarise(group_by(per_sim_grid, Method, Recall),
                        Precision = mean(Precision, na.rm = TRUE),
                        .groups = "drop")

  list(pr_curve = averaged)
}


# Draw the averaged PR curves in processed$pr_curve (columns Recall, Precision
# and Method).
# Each method has a fixed colour and line type, both axes are shown on
# [0, 1], and the untitled legend sits at the bottom.
#
# Args:
#   processed: list returned by process_pr_data(); only $pr_curve is used.
#   main_title, sub_title: optional plot title and subtitle (NULL = none).
# Returns: a ggplot object.
create_pr_plot <- function(processed, main_title = NULL, sub_title = NULL) {
  line_types <- c(
    "Oracle-Trans-Ising" = "solid",
    "Trans-Ising"        = "dashed",
    "Naive-LogLasso"     = "solid",
    "Pooled-Trans-Ising" = "solid"
  )
  line_colours <- c(
    "Oracle-Trans-Ising" = "black",
    "Trans-Ising"        = "cyan",
    "Naive-LogLasso"     = "forestgreen",
    "Pooled-Trans-Ising" = "red"
  )
  pr_theme <- theme_minimal(base_family = "Times New Roman") +
    theme(
      text = element_text(size = 27),
      plot.title = element_text(hjust = 0.5, size = 27, face = "bold"),
      plot.subtitle = element_text(hjust = 0.5, size = 25, margin = margin(b = 10)),
      axis.title = element_text(size = 25),
      axis.text = element_text(size = 18),
      legend.position = "bottom",
      legend.title = element_blank(),
      legend.text = element_text(size = 30),
      legend.key.size = unit(1.9, "cm"),
      legend.margin = margin(t = 20)
    )

  ggplot(processed$pr_curve,
         aes(x = Recall, y = Precision, color = Method, linetype = Method)) +
    geom_line(size = 1.2, na.rm = TRUE) +
    scale_linetype_manual(values = line_types) +
    scale_color_manual(values = line_colours) +
    labs(title = main_title,
         subtitle = sub_title,
         x = "Recall",
         y = "Precision",
         color = "Method",
         linetype = "Method") +
    coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
    pr_theme
}


# Grid of PR curves: one row per perturbation SD in `sd_levels` and one column
# per graph type in `graph_configs`. Graph names title the first row. A
# missing result file triggers a warning and a blank placeholder panel, so
# the grid stays aligned. The combined figure is saved only if at least one
# results file was found.
sd_levels <- c(0.01, 0.1, 0.2)
graph_configs <- list(
  "block_star"        = "Block-wise Star",
  "sparse_block_star" = "Sparse Block Star",
  "chain"             = "Chain",
  "random"            = "Random Sparse"
)

all_plots <- list()
n_loaded <- 0L  # number of real (non-placeholder) panels in all_plots

for (sd_val in sd_levels) {

  perturb_text <- switch(as.character(sd_val),
                         "0.01" = "perturbation = 1",
                         "0.1"  = "perturbation = 10",
                         "0.2"  = "perturbation = 20",
                         paste("perturbation =", sd_val))

  for (g_type in names(graph_configs)) {

    results_file <- sprintf("sim_results_pr_g%.2f_%s.rds", sd_val, g_type)

    if (file.exists(results_file)) {
      raw_df <- readRDS(results_file)
      processed <- process_pr_data(raw_df)

      # Graph names head the first row only. Compare against sd_levels[[1]]
      # rather than a float literal so this follows any change to the levels.
      current_main_title <- if (identical(sd_val, sd_levels[[1]])) graph_configs[[g_type]] else NULL

      p <- create_pr_plot(processed,
                          main_title = current_main_title,
                          sub_title = perturb_text)

      all_plots[[length(all_plots) + 1]] <- p
      n_loaded <- n_loaded + 1L

    } else {
      warning(sprintf("File missing: %s", results_file))
      all_plots[[length(all_plots) + 1]] <- ggplot() + theme_void()
    }
  }
}

# all_plots always has one entry per cell (placeholders included), so test
# the number of real panels rather than length(all_plots).
if (n_loaded > 0) {

  final_combined_plot <- wrap_plots(all_plots, ncol = 4) +
    plot_layout(guides = "collect") +
    plot_annotation(theme = theme(legend.position = "bottom"))

  output_filename <- "final_pr_curves_integrated.png"

  ggsave(filename = output_filename,
         plot = final_combined_plot,
         width = 40, height = 30, units = "cm", dpi = 600)

  cat(sprintf("Saved integrated PR plot to '%s'\n", output_filename))

} else {
  cat("No plots created.\n")
}





############################
### ROC curve simulation ###
############################

# ROC simulation. For each of `num_sim` replicates:
#   * generate target data plus 3 informative and 9 irrelevant auxiliary
#     sources;
#   * fit the four estimators;
#   * compute the ROC curve of |estimate| as a score for the true support
#     (theta_true != 0) over all entries of the coupling matrix.
#
# This definition replaces the PR-returning run_roc_simulation() defined
# earlier in the file once it is evaluated.
#
# Args:
#   p, n_0, n_A: node count, target sample size, per-source auxiliary size.
#   sd_perturb: noise SD for the informative sources.
#   lambda_j, lambda_delta: forwarded unchanged to the estimators.
#   num_sim: number of replicates.
#   graph_type: forwarded to generate_theta_true().
# Returns: data.frame with columns FPR (1 - specificity), TPR (sensitivity),
#   Method, AUC (repeated on every row of a curve) and Sim.
run_roc_simulation <- function(p = 200, n_0 = 160, n_A = 300,
                               sd_perturb = 0.2,
                               lambda_j = NULL, lambda_delta = NULL,
                               num_sim = 1,
                               graph_type) {
  roc_df_list <- vector("list", num_sim)

  # seq_len(): 1:num_sim would iterate over 1 and 0 when num_sim is 0.
  for (sim in seq_len(num_sim)) {
    message(sprintf("Graph: %s, Simulation: %d/%d", graph_type, sim, num_sim))

    theta_true <- generate_theta_true(p, graph_type = graph_type)
    x_0 <- IsingSampler(n_0, theta_true, thresholds = -rowSums(theta_true)/2)

    x_A_inf <- lapply(1:3, function(s) {
      perturb <- matrix(rnorm(p^2, 0, sd_perturb), p, p)
      theta_inf <- (theta_true + perturb + t(theta_true + perturb)) / 2
      diag(theta_inf) <- 0
      IsingSampler(n_A, theta_inf, thresholds = -rowSums(theta_inf)/2)
    })

    x_A_irrel <- lapply(1:9, function(s) {
      theta_bad <- generate_theta_true_2(p)
      IsingSampler(n_A, theta_bad, thresholds = -rowSums(theta_bad)/2)
    })

    x_A_all <- c(x_A_inf, x_A_irrel)

    theta_estimates <- list(
      "Oracle-Trans-Ising" = oracle_trans_loglasso(x_A_inf, x_0, lambda_j, lambda_delta)$beta_hat,
      "Trans-Ising"        = trans_loglasso(x_A_all, x_0, lambda_j, lambda_delta)$beta_hat,
      "Naive-LogLasso"     = naive_loglasso(x_0, lambda_delta),
      "Pooled-Trans-Ising" = oracle_trans_loglasso(x_A_all, x_0, lambda_j, lambda_delta)$beta_hat
    )

    is_edge <- as.vector(theta_true != 0)

    sim_results <- lapply(names(theta_estimates), function(method_name) {
      est <- theta_estimates[[method_name]]
      # Fix controls/cases and the direction explicitly (larger |estimate| =
      # edge). pROC's default direction = "auto" re-chooses the orientation
      # from the data on every call, which can flip a weak estimator's curve
      # and push its AUC above 0.5.
      roc_obj <- roc(is_edge, abs(as.vector(est)),
                     levels = c(FALSE, TRUE), direction = "<", quiet = TRUE)
      data.frame(
        FPR = 1 - roc_obj$specificities,
        TPR = roc_obj$sensitivities,
        Method = method_name,
        AUC = as.numeric(auc(roc_obj)),
        Sim = sim
      )
    })

    roc_df_list[[sim]] <- do.call(rbind, sim_results)
  }

  do.call(rbind, roc_df_list)
}

# Bin-average raw ROC points across simulations.
#
# FPR is cut into 100 equal-width bins on [0, 1]. Within each (Method, bin),
# FPR, TPR and AUC are averaged. One (FPR = 0, TPR = 0) start row per method
# is prepended, carrying the mean of that method's binned AUC values, and the
# result is ordered by Method, then FPR.
#
# NOTE: a later definition of process_roc_data() in this file (the
# grid-interpolating one) replaces this one when it is evaluated.
process_roc_data <- function(raw_roc_df) {
  # include.lowest = TRUE puts FPR == 0 into the first bin.
  fpr_breaks <- seq(0, 1, length.out = 101)

  binned <- raw_roc_df %>%
    mutate(FPR_bin = cut(FPR, breaks = fpr_breaks, include.lowest = TRUE)) %>%
    group_by(Method, FPR_bin)

  avg_roc_df <- summarise(binned,
                          FPR = mean(FPR, na.rm = TRUE),
                          TPR = mean(TPR, na.rm = TRUE),
                          AUC = mean(AUC, na.rm = TRUE),
                          .groups = "drop")

  origin_rows <- avg_roc_df %>%
    group_by(Method) %>%
    summarise(FPR = 0,
              TPR = 0,
              AUC = mean(AUC, na.rm = TRUE))

  arrange(bind_rows(origin_rows, avg_roc_df), Method, FPR)
}


# Plot bin-averaged ROC curves (FPR vs TPR, one colour per method). Each
# method's mean AUC is printed as text at x = 0.65, with the labels stacked
# from y = 0.4 down to y = 0.1.
#
# Args:
#   processed_df: data frame with FPR, TPR, AUC and Method columns.
#   plot_title: title shown above the plot.
# Returns: a ggplot object.
#
# NOTE: a later definition of create_roc_plot() in this file replaces this
# one when it is evaluated.
create_roc_plot <- function(processed_df, plot_title) {
  auc_labels <- processed_df %>%
    group_by(Method) %>%
    summarise(AUC = mean(AUC)) %>%
    mutate(x = 0.65, y = seq(0.4, 0.1, length.out = n()))

  roc_lines <- ggplot(processed_df, aes(x = FPR, y = TPR, color = Method)) +
    geom_line(size = 1.5) +
    labs(title = plot_title,
         x = "1 - Specificity",
         y = "Sensitivity",
         color = "Method")

  styled <- roc_lines +
    theme_minimal() +
    theme(
      plot.title = element_text(hjust = 0.5, size = 20),
      legend.title = element_text(size = 16),
      legend.text = element_text(size = 14),
      axis.title = element_text(size = 16),
      axis.text = element_text(size = 14)
    )

  styled +
    geom_text(
      data = auc_labels,
      aes(x = x, y = y, label = paste0("AUC = ", round(AUC, 3)), color = Method),
      size = 6, show.legend = FALSE
    )
}


# ROC simulation driver. For each graph type, run `total_simulations`
# replicates at sd_perturb = perturb_level with the ROC-returning
# run_roc_simulation() defined above. The raw curves are saved as
# sim_results_g<sd>_<graph>.rds.
graph_configs <- list(
  "block_star"        = "Block-wise Star graph",
  "sparse_block_star" = "Sparse Block Star Graph",
  "chain"             = "Chain Graph",
  "random"            = "Random Sparse Graph"
)


perturb_level <- 0.01
total_simulations <- 100

cat(sprintf("--- Simulation Start (sd_perturb = %.2f) ---\n", perturb_level))

for (g_type in names(graph_configs)) {

  cat(sprintf("\n>>> Graph Type: %s Simulation Start...\n", g_type))

  raw_df <- run_roc_simulation(graph_type = g_type, sd_perturb = perturb_level, num_sim = total_simulations)

  output_filename <- sprintf("sim_results_g%.2f_%s.rds", perturb_level, g_type)
  saveRDS(raw_df, file = output_filename)

  cat(sprintf(">>> Graph Type: %s Simulation Complete. Saved to '%s' \n", g_type, output_filename))
}

cat(sprintf("\n--- All Complete (sd_perturb = %.2f) ---\n", perturb_level))





### Making ROC plots ###

# Interpolate one simulation's ROC points onto a fixed FPR grid.
#
# Steps:
#   1. Keep one point per distinct FPR: the highest TPR there, i.e. the top
#      of the empirical curve's vertical segment. For example, TPR at
#      FPR = 0 is the sensitivity reached with no false positives.
#   2. Anchor the curve at (0, 0) and (1, 1) when those FPRs are absent.
#   3. Force TPR to be non-decreasing with cummax().
#   4. Linearly interpolate onto `grid`; approx(rule = 2) clamps any grid
#      value beyond the end points.
#
# Args:
#   df_one_sim: data frame with numeric FPR and TPR columns.
#   grid: FPR values at which to evaluate the curve.
# Returns: tibble(FPR = grid, TPR = interpolated TPR). TPR is all NA when
#   df_one_sim has no rows.
.roc_monotone_and_grid <- function(df_one_sim, grid = seq(0, 1, length.out = 1001)) {
  df2 <- df_one_sim %>%
    # desc(TPR) makes distinct() keep the highest TPR at each FPR; with
    # ascending TPR it kept the lowest and cut off the vertical steps.
    arrange(FPR, desc(TPR)) %>%
    distinct(FPR, .keep_all = TRUE)

  # Without any point, the FPR[1] comparisons below would be NA.
  if (nrow(df2) == 0) {
    return(tibble(FPR = grid, TPR = rep(NA_real_, length(grid))))
  }

  if (df2$FPR[1] > 0) df2 <- bind_rows(tibble(FPR = 0, TPR = 0), df2)
  if (df2$FPR[nrow(df2)] < 1) df2 <- bind_rows(df2, tibble(FPR = 1, TPR = 1))

  df2$TPR <- cummax(df2$TPR)

  tpr <- approx(x = df2$FPR, y = df2$TPR, xout = grid, ties = "ordered", rule = 2)$y
  tibble(FPR = grid, TPR = tpr)
}

# Average ROC curves across simulations on a common FPR grid.
#
# Each (Method, Sim) group is turned into an interpolated curve on `grid` by
# .roc_monotone_and_grid(). TPR is then averaged over simulations at every
# grid FPR, ignoring NA values.
#
# Args:
#   raw_roc_df: rows with FPR, TPR, Method and optionally Sim. Without a Sim
#     column, all rows are treated as one simulation.
#   grid: FPR values at which curves are evaluated.
# Returns: list(roc_curve = tibble with Method, FPR and mean TPR).
process_roc_data <- function(raw_roc_df, grid = seq(0, 1, length.out = 1001)) {
  if (!"Sim" %in% names(raw_roc_df)) {
    raw_roc_df$Sim <- 1L
  }

  per_sim_grid <- ungroup(group_modify(
    group_by(raw_roc_df, Method, Sim),
    ~ .roc_monotone_and_grid(.x[, c("FPR", "TPR")], grid = grid)
  ))

  averaged <- summarise(group_by(per_sim_grid, Method, FPR),
                        TPR = mean(TPR, na.rm = TRUE),
                        .groups = "drop")

  list(roc_curve = averaged)
}


# Draw the averaged ROC curves in processed$roc_curve (columns FPR, TPR and
# Method).
# Each method has a fixed colour and line type, both axes are shown on
# [0, 1], and the untitled legend sits at the bottom.
#
# Args:
#   processed: list returned by process_roc_data(); only $roc_curve is used.
#   main_title, sub_title: optional plot title and subtitle (NULL = none).
# Returns: a ggplot object.
create_roc_plot <- function(processed, main_title = NULL, sub_title = NULL) {
  line_types <- c(
    "Oracle-Trans-Ising" = "solid",
    "Trans-Ising"        = "dashed",
    "Naive-LogLasso"     = "solid",
    "Pooled-Trans-Ising" = "solid"
  )
  line_colours <- c(
    "Oracle-Trans-Ising" = "black",
    "Trans-Ising"        = "cyan",
    "Naive-LogLasso"     = "forestgreen",
    "Pooled-Trans-Ising" = "red"
  )
  roc_theme <- theme_minimal(base_family = "Times New Roman") +
    theme(
      text = element_text(size = 27),
      plot.title = element_text(hjust = 0.5, size = 27, face = "bold"),
      plot.subtitle = element_text(hjust = 0.5, size = 25, margin = margin(b = 10)),
      axis.title = element_text(size = 25),
      axis.text = element_text(size = 18),
      legend.position = "bottom",
      legend.title = element_blank(),
      legend.text = element_text(size = 30),
      legend.key.size = unit(1.9, "cm"),
      legend.margin = margin(t = 20)
    )

  ggplot(processed$roc_curve,
         aes(x = FPR, y = TPR, color = Method, linetype = Method)) +
    geom_line(size = 1.2, na.rm = TRUE) +
    scale_linetype_manual(values = line_types) +
    scale_color_manual(values = line_colours) +
    labs(title = main_title,
         subtitle = sub_title,
         x = "1 - Specificity (FPR)",
         y = "Sensitivity (TPR)",
         color = "Method",
         linetype = "Method") +
    coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
    roc_theme
}


# Grid of ROC curves: one row per perturbation SD in `sd_levels` and one
# column per graph type in `graph_configs`. Graph names title the first row.
# A missing result file triggers a warning and a blank placeholder panel, so
# the grid stays aligned. The combined figure is saved only if at least one
# results file was found.
sd_levels <- c(0.01, 0.1, 0.2)
graph_configs <- list(
  "block_star"        = "Block-wise Star",
  "sparse_block_star" = "Sparse Block Star",
  "chain"             = "Chain",
  "random"            = "Random Sparse"
)

all_plots <- list()
n_loaded <- 0L  # number of real (non-placeholder) panels in all_plots

for (sd_val in sd_levels) {

  perturb_text <- switch(as.character(sd_val),
                         "0.01" = "perturbation = 1",
                         "0.1"  = "perturbation = 10",
                         "0.2"  = "perturbation = 20",
                         paste("perturbation =", sd_val))

  for (g_type in names(graph_configs)) {

    results_file <- sprintf("sim_results_g%.2f_%s.rds", sd_val, g_type)

    if (file.exists(results_file)) {
      raw_df <- readRDS(results_file)
      processed <- process_roc_data(raw_df)

      # Graph names head the first row only. Compare against sd_levels[[1]]
      # rather than a float literal so this follows any change to the levels.
      current_main_title <- if (identical(sd_val, sd_levels[[1]])) graph_configs[[g_type]] else NULL

      p <- create_roc_plot(processed,
                           main_title = current_main_title,
                           sub_title = perturb_text)

      all_plots[[length(all_plots) + 1]] <- p
      n_loaded <- n_loaded + 1L

    } else {
      warning(sprintf("File missing: %s", results_file))
      all_plots[[length(all_plots) + 1]] <- ggplot() + theme_void()
    }
  }
}

# all_plots always has one entry per cell (placeholders included), so test
# the number of real panels rather than length(all_plots).
if (n_loaded > 0) {

  final_combined_plot <- wrap_plots(all_plots, ncol = 4) +
    plot_layout(guides = "collect") +
    plot_annotation(theme = theme(legend.position = "bottom"))

  output_filename <- "final_roc_plots_integrated.png"

  ggsave(filename = output_filename,
         plot = final_combined_plot,
         width = 40, height = 30, units = "cm", dpi = 600)

  cat(sprintf("Saved integrated ROC plot to '%s'\n", output_filename))

} else {
  cat("No plots created.\n")
}
