# # NOTE(review): every line of this script is commented out, so sourcing this
# # file currently has no effect. To re-enable it, strip one leading "# " from
# # each line; these NOTE(review) lines then remain ordinary R comments.
# #
# # When enabled: for each graph size and seed, every explanation type is
# # compared with the "causal_true" explanation via compute_normalized_l2(),
# # and the per-datapoint errors are written to a CSV file.
# #
# # NOTE(review): rm(list = ls()) wipes the caller's global environment when
# # this file is sourced; consider dropping it when re-enabling the script.
# rm(list = ls())
#
# # These presumably define retrieve_explanations() and compute_normalized_l2(),
# # which are used below but not defined in this file.
# source('R/experiments/analysis/metrics.R')
# source('R/experiments/analysis/helper_functions.R')
#
# library(tidyverse)
#
# # Parameters
# # nmc_causal_true and n_combinations_causal_true_values apply only to the
# # "causal_true" explanation; nmc and n_combinations_values to all other
# # types. nobs is passed for every type except "causal_true" (which gets NULL).
# nmc <- 1024
# nmc_causal_true <- 4096
# nobs <- 1024
#
# n_seeds <- 40
# n_to_explain <- 40
#
# # The four vectors below are indexed in parallel by idx_n_nodes in the main
# # loop, so they must all have the same length.
# n_nodes_values <- c(6, 11, 16)
# n_combinations_values <- c(32, 1024, 8192)
# n_combinations_causal_true_values <- c(32, 1024, 16384)
# dag_types <- c('2_neighbors_3_parents_y', '2_neighbors_6_parents_y', '2_neighbors_9_parents_y')
#
# # Initialize results tibble
# # One row per compared datapoint; idx_mec is NA for explanation types whose
# # stored result is a single data frame rather than a list.
# results <- tibble(
#   n_nodes = integer(),
#   seed = integer(),
#   explanation_type = character(),
#   idx_mec = integer(),
#   idx_datapoint = integer(),
#   l2_error = double()
# )
#
# # Helper function to add results to the tibble
# #
# # Returns `results` with one new row appended per element of `l2_errors`
# # (R copies on modify, so the caller's tibble is not changed in place).
# #   results          tibble with the columns initialised above
# #   n_nodes, seed    repeated on every new row (expected to be length 1)
# #   explanation_type repeated on every new row (expected to be length 1)
# #   l2_errors        vector of errors; element i is stored with idx_datapoint = i
# #   idx_mec          repeated on every new row; NA when there is no MEC index
# #
# # NOTE(review): calling this inside nested loops copies the whole growing
# # tibble on every call (quadratic in the total number of rows); collecting
# # the per-call tibbles in a list and binding once at the end would be linear.
# add_results <- function(results, n_nodes, seed, explanation_type, l2_errors, idx_mec = NA) {
#   results <- results %>%
#     bind_rows(tibble(
#       n_nodes = rep(n_nodes, length(l2_errors)),
#       seed = rep(seed, length(l2_errors)),
#       explanation_type = rep(explanation_type, length(l2_errors)),
#       idx_mec = rep(idx_mec, length(l2_errors)),
#       idx_datapoint = seq_along(l2_errors),
#       l2_error = l2_errors
#     ))
#     # NOTE(review): over-indented; this belongs at the function-body level.
#     return(results)
# }
#
# # Main processing loop
# # NOTE(review): the order of explanation_types matters. Element 1 must be
# # "causal_true", because explanations[[1]] is used as the reference and the
# # comparison loop below starts at index 2.
# explanation_types <- c("causal_true", "marginal", "conditional", "oracle_mec_sample", "oracle_mec_iw", "pc_1024n_alpha0.05_mec_sample", "pc_1024n_alpha0.05_mec_iw", "fges_1024_mec_sample", "fges_1024_mec_iw")
#
# for (idx_n_nodes in seq_along(n_nodes_values)) {
#   n_nodes <- n_nodes_values[idx_n_nodes]
#   n_combinations <- n_combinations_values[idx_n_nodes]
#   n_combinations_causal_true <- n_combinations_causal_true_values[idx_n_nodes]
#
#   dag_type <- dag_types[idx_n_nodes]
#   # NOTE(review): the trailing "/" in the first component makes file.path()
#   # produce a double slash (".../gaussian//6_nodes/..."); harmless on most OSes.
#   dir <- file.path('experiments/6_explanations/xgboost/gaussian/', paste0(n_nodes, '_nodes'), dag_type)
#
#   # Per-type settings, aligned element-wise with explanation_types.
#   nmc_explanations <- ifelse(explanation_types == "causal_true", nmc_causal_true, nmc)
#   n_combinations_explanations <- ifelse(explanation_types == "causal_true", n_combinations_causal_true, n_combinations)
#   # ifelse() cannot yield NULL elements (an atomic vector cannot hold NULL),
#   # hence the scalar `if` used for current_nobs inside the lapply() below.
#   #n_obs_values <- ifelse(explanation_types == "causal_true", NULL, nobs) # this does not work
#
#   for (seed in seq_len(n_seeds)) {
#
#     # One entry per explanation type, in the order of explanation_types.
#     # Presumably NULL when no stored result is found; the is.null() checks
#     # below rely on that (TODO confirm against retrieve_explanations()).
#     explanations <- lapply(seq_along(explanation_types), function(explanation_idx) {
#       current_explanation_type <- explanation_types[explanation_idx]
#       current_nmc <- nmc_explanations[explanation_idx]
#       current_n_combinations <- n_combinations_explanations[explanation_idx]
#       current_nobs <- if (current_explanation_type == "causal_true") NULL else nobs
#
#       return(retrieve_explanations(dir, seed, current_explanation_type, current_nobs, current_nmc, current_n_combinations,
#                                                    n_to_explain, return_shapley_values = TRUE, use_first_one = TRUE))
#     })
#
#     # The reference explanation; without it this seed cannot be scored.
#     explanation_true <- explanations[[1]]
#     if (is.null(explanation_true)) {
#       print(paste0("Missing true causal explanation for ", n_nodes, " nodes, seed ", seed))
#       next
#     }
#
#     # NOTE(review): 2:length(explanations) is only correct while there are at
#     # least two explanation types; seq_along(explanations)[-1] also covers 1.
#     for (i in 2:length(explanations)) {
#       explanation <- explanations[[i]]
#       explanation_type <- explanation_types[i]
#       if (!is.null(explanation)) {
#         if (is.data.frame(explanation)) {  # Non-list case
#           l2_errors <- compute_normalized_l2(explanation_true, explanation)
#           results <- add_results(results, n_nodes, seed, explanation_type, l2_errors)
#         } else {  # List case for multiple explanations
#           # Presumably one explanation per member of the MEC; idx_mec records
#           # the position within that list.
#           for (idx_mec in seq_along(explanation)) {
#             l2_errors <- compute_normalized_l2(explanation_true, explanation[[idx_mec]])
#             results <- add_results(results, n_nodes, seed, explanation_type, l2_errors, idx_mec)
#           }
#         }
#       } else {
#         print(paste0("Missing explanation for ", explanation_type, " for ", n_nodes, " nodes, seed ", seed))
#       }
#     }
#   }
# }
#
# # rename explanation_types
# # Shortens the long configuration labels to compact names
# # (e.g. "pc_1024n_alpha0.05_mec_iw" -> "pc_iw").
# # NOTE(review): the first three mappings are identity no-ops, and
# # "causal_true" never occurs in `results`, because the main loop only adds
# # explanation types from index 2 onward.
# results$explanation_type <- recode(results$explanation_type,
#                                    "causal_true" = "causal_true",
#                                    "marginal" = "marginal",
#                                    "conditional" = "conditional",
#                                    "oracle_mec_sample" = "oracle_sample",
#                                    "oracle_mec_iw" = "oracle_iw",
#                                    "pc_1024n_alpha0.05_mec_sample" = "pc_sample",
#                                    "pc_1024n_alpha0.05_mec_iw" = "pc_iw",
#                                    "fges_1024_mec_sample" = "fges_sample",
#                                    "fges_1024_mec_iw" = "fges_iw")
#
# # NOTE(review): this output path and the input `dir` are both relative, so
# # the script presumably expects the project root as working directory (TODO confirm).
# write_delim(results, 'R/experiments/results/gaussian/results_nodes.csv', delim = ',')