rm(list = ls())

source('R/experiments/analysis/metrics.R')
source('R/experiments/analysis/helper_functions.R')

library(tidyverse)

# Parameters ----

# Settings forwarded to retrieve_explanations() in the main loop.
# nmc_causal_true replaces nmc when the "causal_true" explanation is loaded.
nobs <- 1000
ncomb <- 1024
nmc <- 1000
nmc_causal_true <- 4000

# n_cd values swept by the outer loop; each value is embedded in the names of
# the pc_* and fges_* explanations that are looked up.
n_cd_values <- c(250, 500, 1000, 2000, 4000)

# Seeds 1..n_seeds are iterated; n_to_explain is passed to
# retrieve_explanations().
n_seeds <- 40
n_to_explain <- 40

# Graph settings. Not referenced elsewhere in this script; the values mirror
# the hard-coded results directory name (11_nodes/2_neighbors_6_parents_y).
n_nodes <- 11
n_parents_y <- 6
n_neighbors_values <- 2


# Empty results table; its columns match the rows appended by add_results().
results <- tibble(
  ncd = integer(0),
  seed = integer(0),
  explanation_type = character(0),
  idx_mec = integer(0),        # stays NA for explanations that are a single data frame
  idx_datapoint = integer(0),  # position of the error within one l2_errors vector
  l2_error = numeric(0)
)

# Append one batch of L2 errors to the results table.
#
# results:          table accumulated so far
# ncd, seed,
# explanation_type: values repeated on every new row
# l2_errors:        vector of errors; one row is added per element and
#                   idx_datapoint records each element's position
# idx_mec:          value repeated on every new row (NA unless supplied)
#
# Returns `results` with the new rows bound underneath.
add_results <- function(results, ncd, seed, explanation_type, l2_errors, idx_mec = NA) {
  n_new <- length(l2_errors)

  new_rows <- tibble(
    ncd = rep(ncd, n_new),
    seed = rep(seed, n_new),
    explanation_type = rep(explanation_type, n_new),
    idx_mec = rep(idx_mec, n_new),
    idx_datapoint = seq_along(l2_errors),
    l2_error = l2_errors
  )

  bind_rows(results, new_rows)
}

# Running sum of the `duration` field of every true causal explanation found
total_time <- 0

# Explanation types to load for each (n_cd, seed) pair. The first entry is the
# reference that all remaining entries are compared against.
explanation_types <- c(
  "causal_true", "marginal", "conditional", "oracle_mec_sample", "oracle_mec_iw",
  "pc_iw", "pc_sample", "fges_iw", "fges_sample"
)

for (idx_n_cd in seq_along(n_cd_values)) {
  n_cd <- n_cd_values[idx_n_cd]

  dir <- file.path("experiments/6_explanations/xgboost/gaussian/11_nodes/2_neighbors_6_parents_y")

  for (seed in seq_len(n_seeds)) {

    # Load every explanation type, in the order given by explanation_types
    explanations <- lapply(explanation_types, function(type) {
      is_true <- type == "causal_true"

      # pc_* and fges_* explanations are looked up under names that embed n_cd;
      # every other type is looked up under its own name (the unnamed default).
      stored_name <- switch(
        type,
        pc_iw = paste0("pc_", n_cd, "n_alpha0.05_standardized_mec_iw"),
        pc_sample = paste0("pc_", n_cd, "n_alpha0.05_standardized_mec_sample"),
        fges_iw = paste0("fges_", n_cd, "_standardized_mec_iw"),
        fges_sample = paste0("fges_", n_cd, "_standardized_mec_sample"),
        type
      )

      # The true causal explanation is requested with nobs = NULL and its own
      # Monte Carlo sample count.
      retrieve_explanations(
        dir, seed, stored_name,
        if (is_true) NULL else nobs,
        if (is_true) nmc_causal_true else nmc,
        ncomb, n_to_explain,
        return_shapley_values = TRUE, use_first_one = TRUE
      )
    })

    # Without the reference explanation nothing can be scored for this seed
    explanation_true <- explanations[[1]]
    if (is.null(explanation_true)) {
      print(paste0("Missing true causal explanation for ", n_cd, " ncd, seed ", seed))
      next
    }
    total_time <- total_time + explanation_true$duration

    # Only score seeds for which every explanation type is available
    if (any(vapply(explanations, is.null, logical(1)))) {
      print(paste0("Missing one or more explanations for ", n_cd, " ncd, seed ", seed))
      next
    }

    # Compare each non-reference explanation with the true causal one
    for (i in seq_along(explanations)[-1]) {
      explanation <- explanations[[i]]
      explanation_type <- explanation_types[i]

      if (!is.data.frame(explanation)) {
        # A list of explanations: score each element and record its index
        for (idx_mec in seq_along(explanation)) {
          l2_errors <- compute_normalized_l2(explanation_true, explanation[[idx_mec]])
          results <- add_results(results, n_cd, seed, explanation_type, l2_errors, idx_mec)
        }
      } else {
        # A single data frame: idx_mec keeps its NA default
        l2_errors <- compute_normalized_l2(explanation_true, explanation)
        results <- add_results(results, n_cd, seed, explanation_type, l2_errors)
      }
    }
  }
}

# Rename explanation types for the output file. Only the two oracle labels
# actually change; recode() leaves values without a replacement untouched, so
# the remaining types need no identity mapping.
# (The original call was followed by a stray ")" that made the script
# unparseable.)
results$explanation_type <- recode(
  results$explanation_type,
  "oracle_mec_sample" = "oracle_sample",
  "oracle_mec_iw" = "oracle_iw"
)

# Write the collected errors as a comma-separated file
write_delim(results, "R/experiments/results/gaussian/results_ncd.csv", delim = ",")
print("Done")
# Optional report of the accumulated total_time (left disabled):
# print(paste0("Total time: ", as.numeric(total_time, units = "days"), " days"))