# Run this script in a fresh R session (e.g. `Rscript path/to/script.R`).
# Wiping the global environment with rm(list = ls()) was removed: it deletes
# the caller's objects but does not reset attached packages or options, so it
# never guaranteed a clean run anyway.

source('R/experiments/analysis/metrics.R')
source('R/experiments/analysis/helper_functions.R')

library(tidyverse)

# Parameters ----
nmc <- 1000
nmc_causal_true <- 4000
nobs <- 1000

n_seeds <- 40
n_to_explain <- 40

# The next four vectors are indexed in parallel: entry i of each configures
# the i-th graph size. Node counts are integers so they match the `n_nodes`
# column type declared in `results` below.
n_nodes_values <- c(6L, 11L, 16L)
n_combinations_values <- c(32, 1024, 4096)
n_combinations_causal_true_values <- c(32, 1024, 8192)
dag_types <- c('2_neighbors_3_parents_y', '2_neighbors_6_parents_y', '2_neighbors_9_parents_y')

# Fail fast if the parallel configuration vectors drift out of sync.
stopifnot(
  length(n_combinations_values) == length(n_nodes_values),
  length(n_combinations_causal_true_values) == length(n_nodes_values),
  length(dag_types) == length(n_nodes_values)
)

# Results schema ----
# One row per (graph size, seed, explanation type, MEC index, data point).
results <- tibble(
  n_nodes = integer(),
  seed = integer(),
  explanation_type = character(),
  idx_mec = integer(),
  idx_datapoint = integer(),
  l2_error = double()
)

# Append one explanation's per-data-point L2 errors to a results tibble.
#
# Args:
#   results:          tibble to append to.
#   n_nodes:          number of nodes of the graph (recycled to every new row).
#   seed:             seed identifier (recycled to every new row).
#   explanation_type: explanation type label (recycled to every new row).
#   l2_errors:        vector of errors; one new row is created per element.
#   idx_mec:          index within a list of explanations, or NA when the
#                     explanation was a single data frame.
#
# Returns:
#   `results` with `length(l2_errors)` rows appended; `idx_datapoint` numbers
#   the new rows 1..length(l2_errors).
add_results <- function(results, n_nodes, seed, explanation_type, l2_errors, idx_mec = NA) {
  n_rows <- length(l2_errors)
  new_rows <- tibble(
    n_nodes = rep(n_nodes, n_rows),
    seed = rep(seed, n_rows),
    explanation_type = rep(explanation_type, n_rows),
    idx_mec = rep(idx_mec, n_rows),
    idx_datapoint = seq_len(n_rows),
    l2_error = l2_errors
  )
  bind_rows(results, new_rows)
}

# Main processing loop ----
explanation_types <- c("causal_true", "marginal", "conditional", "oracle_mec_sample", "oracle_mec_iw",
                        "pc_1000n_alpha0.05_standardized_mec_sample", "pc_1000n_alpha0.05_standardized_mec_iw",
                        "fges_1000_standardized_mec_sample", "fges_1000_standardized_mec_iw")

# Every other explanation type is scored against this reference type.
reference_type <- "causal_true"
is_reference <- explanation_types == reference_type
compared_types <- explanation_types[!is_reference]

# The reference uses its own Monte Carlo budget. This does not depend on the
# graph size, so it is computed once outside the loop.
nmc_explanations <- ifelse(is_reference, nmc_causal_true, nmc)

# New rows are collected as a list of tibbles and bound once after the loop.
# Calling bind_rows() on `results` every iteration would copy the growing
# tibble each time (quadratic in the number of rows).
result_chunks <- list()

for (idx_n_nodes in seq_along(n_nodes_values)) {
  n_nodes <- n_nodes_values[idx_n_nodes]
  n_combinations_explanations <- ifelse(is_reference,
                                        n_combinations_causal_true_values[idx_n_nodes],
                                        n_combinations_values[idx_n_nodes])
  explanation_dir <- file.path('experiments/6_explanations/xgboost/non-linear1',
                               paste0(n_nodes, '_nodes'), dag_types[idx_n_nodes])

  for (seed in seq_len(n_seeds)) {
    explanations <- lapply(seq_along(explanation_types), function(explanation_idx) {
      # The reference explanation is retrieved with `nobs = NULL`.
      current_nobs <- if (is_reference[explanation_idx]) NULL else nobs
      retrieve_explanations(explanation_dir, seed, explanation_types[explanation_idx], current_nobs,
                            nmc_explanations[explanation_idx], n_combinations_explanations[explanation_idx],
                            n_to_explain, return_shapley_values = TRUE, use_first_one = TRUE)
    })
    # Name the list so elements are looked up by type, not by position.
    names(explanations) <- explanation_types
    explanation_true <- explanations[[reference_type]]

    if (is.null(explanation_true)) {
      message("Missing true causal explanation for ", n_nodes, " nodes, seed ", seed)
      next
    }
    if (any(vapply(explanations, is.null, logical(1)))) {
      message("Missing one or more explanations for ", n_nodes, " nodes, seed ", seed)
      next
    }

    for (explanation_type in compared_types) {
      explanation <- explanations[[explanation_type]]
      if (is.data.frame(explanation)) {
        # A single explanation: its rows are stored with idx_mec = NA.
        l2_errors <- compute_normalized_l2(explanation_true, explanation)
        result_chunks[[length(result_chunks) + 1L]] <-
          add_results(results[0, ], n_nodes, seed, explanation_type, l2_errors)
      } else {
        # A list of explanations (presumably one per MEC member; verify in
        # retrieve_explanations); idx_mec records each element's position.
        for (idx_mec in seq_along(explanation)) {
          l2_errors <- compute_normalized_l2(explanation_true, explanation[[idx_mec]])
          result_chunks[[length(result_chunks) + 1L]] <-
            add_results(results[0, ], n_nodes, seed, explanation_type, l2_errors, idx_mec)
        }
      }
    }
  }
}

# bind_rows() accepts a list of data frames; an empty list leaves `results` as is.
results <- bind_rows(results, result_chunks)

# Shorten explanation-type labels for the output file ----
# recode() leaves values that are absent from this lookup unchanged (for
# example "marginal" and "conditional"), so identity mappings are unnecessary.
explanation_labels <- c(
  "oracle_mec_sample" = "oracle_sample",
  "oracle_mec_iw" = "oracle_iw",
  "pc_1000n_alpha0.05_standardized_mec_sample" = "pc_sample",
  "pc_1000n_alpha0.05_standardized_mec_iw" = "pc_iw",
  "fges_1000_standardized_mec_sample" = "fges_sample",
  "fges_1000_standardized_mec_iw" = "fges_iw"
)
results$explanation_type <- recode(results$explanation_type, !!!explanation_labels)

# Write results ----
output_path <- 'R/experiments/results/non-linear1/results_nodes.csv'
# Create the output directory if it is missing, so the final write cannot fail
# after the long computation above has already run.
dir.create(dirname(output_path), recursive = TRUE, showWarnings = FALSE)
write_csv(results, output_path)
print('Done')