# Remove every (non-hidden) object from the environment this script runs in,
# so earlier session state cannot leak into the analysis.
# NOTE(review): this does not unload packages or reset options; running the
# script in a fresh R session is the more reliable way to get a clean state.
rm(list = ls())

# Project helpers. The paths are relative, so the script has to be run with
# the working directory set to the project root.
# Presumably these files define compute_normalized_l2() and
# retrieve_explanations(), which are called further down - confirm.
source('R/experiments/analysis/metrics.R')
source('R/experiments/analysis/helper_functions.R')

# Provides tibble(), bind_rows(), recode(), write_delim() and %>% used below.
library(tidyverse)

# Parameters
# `nmc` is handed to retrieve_explanations() for every explanation type
# except "causal_true", which uses `nmc_causal_true` instead.
nmc <- 1000
nmc_causal_true <- 4000
# Handed to retrieve_explanations() for every type except "causal_true"
# (which receives NULL in its place).
nobs <- 1000
# Handed unchanged to retrieve_explanations() for every explanation type.
ncomb <- 1024

# Seeds 1..n_seeds are processed for each neighbor setting.
n_seeds <- 40
# Handed unchanged to retrieve_explanations().
n_to_explain <- 40

# Graph settings. Presumably these correspond to the "11_nodes" and
# "6_parents_y" components of the results directory - confirm.
n_nodes <- 11
n_parents_y <- 6
# Each value selects one "<n>_neighbors_..." results sub-directory.
n_neighbors_values <- c(1,2, 3)


# Initialize results tibble
# One row per (n_neighbors, seed, explanation_type, idx_mec, idx_datapoint).
# idx_mec stays NA for explanation types that yield a single data frame.
results <- tibble(
  n_neighbors = integer(),
  seed = integer(),
  explanation_type = character(),
  idx_mec = integer(),
  idx_datapoint = integer(),
  l2_error = double()
)

# Append one batch of per-datapoint errors to the results table.
#
# Arguments:
#   results          tibble collected so far.
#   n_neighbors      neighbor setting of the batch (repeated on every row).
#   seed             seed of the batch (repeated on every row).
#   explanation_type label of the explanation (repeated on every row).
#   l2_errors        vector of errors; element i becomes the row with
#                    idx_datapoint == i.
#   idx_mec          optional index stored in the idx_mec column; NA when
#                    not supplied.
#
# Returns `results` with length(l2_errors) rows added at the bottom.
add_results <- function(results, n_neighbors, seed, explanation_type, l2_errors, idx_mec = NA) {
  n_rows <- length(l2_errors)

  new_rows <- tibble(
    n_neighbors = rep(n_neighbors, n_rows),
    seed = rep(seed, n_rows),
    explanation_type = rep(explanation_type, n_rows),
    idx_mec = rep(idx_mec, n_rows),
    idx_datapoint = seq_len(n_rows),
    l2_error = l2_errors
  )

  bind_rows(results, new_rows)
}

# Total compute time
# Running sum of the `duration` field of every reference ("causal_true")
# explanation that was found.
total_time <- 0

# Explanation types to evaluate. The first entry is the reference that all
# later entries are compared against, so "causal_true" has to stay first.
explanation_types <- c("causal_true", "marginal", "conditional", "oracle_mec_sample", "oracle_mec_iw",
                        "pc_1000n_alpha0.05_standardized_mec_sample", "pc_1000n_alpha0.05_standardized_mec_iw",
                        "fges_1000_standardized_mec_sample", "fges_1000_standardized_mec_iw")

# `nmc` argument per explanation type: the reference gets nmc_causal_true,
# every other type gets nmc.
nmc_explanations <- ifelse(explanation_types == "causal_true", nmc_causal_true, nmc)

# Main processing loop
for (idx_n_neighbors in seq_along(n_neighbors_values)) {
  n_neighbors <- n_neighbors_values[idx_n_neighbors]

  # Results directory of this graph configuration; with n_nodes = 11 and
  # n_parents_y = 6 this is ".../11_nodes/<n_neighbors>_neighbors_6_parents_y".
  dir <- file.path('experiments/6_explanations/xgboost/gaussian',
                   paste0(n_nodes, '_nodes'),
                   paste0(n_neighbors, '_neighbors_', n_parents_y, '_parents_y'))

  for (seed in seq_len(n_seeds)) {

    # One entry per explanation type, in the order of `explanation_types`.
    # A NULL entry is treated as "explanation missing" below.
    explanations <- lapply(seq_along(explanation_types), function(explanation_idx) {
      current_explanation_type <- explanation_types[explanation_idx]
      current_nmc <- nmc_explanations[explanation_idx]
      # ifelse() cannot produce NULL, so the per-type nobs is picked with a
      # scalar if/else here: NULL for the reference, nobs for the rest.
      current_nobs <- if (current_explanation_type == "causal_true") NULL else nobs

      retrieve_explanations(dir, seed, current_explanation_type, current_nobs, current_nmc, ncomb,
                            n_to_explain, return_shapley_values = TRUE, use_first_one = TRUE)
    })

    # Without the reference nothing can be scored for this seed.
    explanation_true <- explanations[[1]]
    if (is.null(explanation_true)) {
      print(paste0("Missing true causal explanation for ", n_neighbors, " neighbors, seed ", seed))
      next
    }
    total_time <- total_time + explanation_true$duration

    # Only score seeds for which every explanation type is available.
    is_missing <- vapply(explanations, is.null, logical(1))
    if (any(is_missing)) {
      print(paste0("Missing one or more explanations for ", n_neighbors, " neighbors, seed ", seed))
      next
    }

    # Skip index 1 (the reference itself). seq_along(x)[-1] is empty when
    # only the reference is listed, whereas 2:length(x) would count backwards.
    for (i in seq_along(explanations)[-1]) {
      explanation <- explanations[[i]]
      explanation_type <- explanation_types[i]
      if (is.data.frame(explanation)) {  # Non-list case
        l2_errors <- compute_normalized_l2(explanation_true, explanation)
        results <- add_results(results, n_neighbors, seed, explanation_type, l2_errors)
      } else {  # List case for multiple explanations
        for (idx_mec in seq_along(explanation)) {
          l2_errors <- compute_normalized_l2(explanation_true, explanation[[idx_mec]])
          results <- add_results(results, n_neighbors, seed, explanation_type, l2_errors, idx_mec)
        }
      }
    }
  }
}

# Rename explanation types to shorter labels for the output file.
# Labels without an entry in this mapping ("causal_true", "marginal",
# "conditional") are left unchanged by recode().
short_names <- c(
  "oracle_mec_sample" = "oracle_sample",
  "oracle_mec_iw" = "oracle_iw",
  "pc_1000n_alpha0.05_standardized_mec_sample" = "pc_sample",
  "pc_1000n_alpha0.05_standardized_mec_iw" = "pc_iw",
  "fges_1000_standardized_mec_sample" = "fges_sample",
  "fges_1000_standardized_mec_iw" = "fges_iw"
)
results$explanation_type <- recode(results$explanation_type, !!!short_names)

# Create the target directory if needed, so the collected results are not
# lost to a write error after the whole loop has run.
output_file <- 'R/experiments/results/gaussian/results_neighbors.csv'
dir.create(dirname(output_file), recursive = TRUE, showWarnings = FALSE)

write_delim(results, output_file, delim = ',')
print('Done')