# Start from a clean workspace: removes every object that ls() reports in the
# environment this script is evaluated in (the global environment when the
# script is run at top level). Loaded packages and options are not affected.
# NOTE(review): when this file is source()d from an interactive session this
# also deletes the caller's own objects — consider dropping it.
rm(list = ls())

source('R/experiments/analysis/metrics.R')
source('R/experiments/analysis/helper_functions.R')

library(tidyverse)

# Parameters ----

# Sizes forwarded to retrieve_explanations(). `nobs` and `nmc` are used for
# every explanation type whose name does not contain "causal_true"; the
# "causal_true" types use `nmc_causal_true` and pass no `nobs`.
nobs <- 1000
nmc <- 1000
nmc_causal_true <- 4000

# ncombs settings iterated by the outer loop of the script.
ncombs_values <- c(
  64,
  128,
  256,
  512,
  1024
)

# Seeds 1..n_seeds are processed; n_to_explain is forwarded to
# retrieve_explanations().
n_to_explain <- 40
n_seeds <- 40

# Graph settings. They are not referenced further down in this script; the
# values match the hard-coded results directory name
# ("11_nodes/2_neighbors_6_parents_y").
n_parents_y <- 6
n_neighbors_values <- 2
n_nodes <- 11


# Empty, typed results table. Each row appended later holds one l2_error value
# together with the ncombs setting, seed, explanation type, MEC index (NA when
# the explanation is a single data frame) and datapoint index it belongs to.
results <- tibble(
  ncombs           = integer(0),
  seed             = integer(0),
  explanation_type = character(0),
  idx_mec          = integer(0),
  idx_datapoint    = integer(0),
  l2_error         = numeric(0)
)

#' Append one batch of L2 errors to the results table.
#'
#' One row is added per element of `l2_errors`; `idx_datapoint` is the
#' position of the element within `l2_errors`, and the remaining identifiers
#' are repeated on every new row.
#'
#' @param results Table the new rows are appended to.
#' @param ncombs,seed,explanation_type Identifiers stored on every new row.
#' @param l2_errors Vector of errors, one per datapoint.
#' @param idx_mec Index stored in the `idx_mec` column; `NA` by default.
#' @return `results` with the new rows bound to the bottom.
add_results <- function(results, ncombs, seed, explanation_type, l2_errors, idx_mec = NA) {
  n_rows <- length(l2_errors)

  new_rows <- tibble(
    ncombs = rep(ncombs, n_rows),
    seed = rep(seed, n_rows),
    explanation_type = rep(explanation_type, n_rows),
    idx_mec = rep(idx_mec, n_rows),
    idx_datapoint = seq_along(l2_errors),
    l2_error = l2_errors
  )

  bind_rows(results, new_rows)
}

# Running sum of the `duration` field of the reference explanations.
# NOTE(review): the reference ("causal_true") explanation is requested with the
# same arguments for every ncombs value, so its duration is added once per
# (ncombs, seed) pair rather than once per seed. total_time is also not read
# again in this script — confirm whether this is intended.
total_time <- 0


# Main processing loop ----

# Explanation types loaded for every (ncombs, seed) pair. The first entry is
# the reference that all other entries are compared against.
explanation_types <- c("causal_true", "causal_true_subsampled", "marginal", "conditional",
                       "oracle_mec_sample", "oracle_mec_iw",
                       "pc_1000n_alpha0.05_standardized_mec_sample", "pc_1000n_alpha0.05_standardized_mec_iw",
                       "fges_1000_standardized_mec_sample", "fges_1000_standardized_mec_iw")

# The reference explanation is always loaded with this many combinations,
# independent of the ncombs value currently being evaluated.
reference_ncombs <- 1024

# Location of the stored explanations; it does not depend on ncombs or seed.
dir <- file.path('experiments/6_explanations/xgboost/non-linear1/11_nodes/2_neighbors_6_parents_y')

# Each comparison produces a small table. They are collected in a list and
# bound to `results` once after the loops, which avoids re-copying the whole
# (growing) results table on every single comparison.
results_template <- results[0, ]
result_chunks <- list()

for (ncombs in ncombs_values) {
  for (seed in seq_len(n_seeds)) {

    explanations <- lapply(explanation_types, function(current_explanation_type) {
      is_causal_true <- grepl("causal_true", current_explanation_type)
      current_nobs <- if (is_causal_true) NULL else nobs
      current_nmc <- if (is_causal_true) nmc_causal_true else nmc

      # "causal_true_subsampled" is the "causal_true" explanation requested
      # with the current ncombs instead of reference_ncombs.
      stored_type <- if (current_explanation_type == "causal_true_subsampled") {
        "causal_true"
      } else {
        current_explanation_type
      }
      current_ncombs <- if (current_explanation_type == "causal_true") {
        reference_ncombs
      } else {
        ncombs
      }

      retrieve_explanations(dir, seed, stored_type, current_nobs, current_nmc, current_ncombs,
                            n_to_explain, return_shapley_values = TRUE, use_first_one = TRUE)
    })

    explanation_true <- explanations[[1]]
    if (is.null(explanation_true)) {
      print(paste0("Missing true causal explanation for ", ncombs, " ncombs, seed ", seed))
      next
    }
    total_time <- total_time + explanation_true$duration

    # Only compare when every explanation type is available for this pair.
    if (any(vapply(explanations, is.null, logical(1)))) {
      print(paste0("Missing one or more explanations for ", ncombs, " ncombs, seed ", seed))
      next
    }

    # Compare every non-reference explanation against the reference.
    for (i in seq_along(explanations)[-1]) {
      explanation <- explanations[[i]]
      explanation_type <- explanation_types[i]
      if (is.data.frame(explanation)) {  # Single explanation
        l2_errors <- compute_normalized_l2(explanation_true, explanation)
        result_chunks[[length(result_chunks) + 1L]] <-
          add_results(results_template, ncombs, seed, explanation_type, l2_errors)
      } else {  # List of explanations, indexed by idx_mec
        for (idx_mec in seq_along(explanation)) {
          l2_errors <- compute_normalized_l2(explanation_true, explanation[[idx_mec]])
          result_chunks[[length(result_chunks) + 1L]] <-
            add_results(results_template, ncombs, seed, explanation_type, l2_errors, idx_mec)
        }
      }
    }
  }
}

# Single bind, in the order the chunks were produced.
results <- bind_rows(results, result_chunks)

# Shorten the explanation-type labels for the output file. Labels that are not
# listed here ("causal_true", "causal_true_subsampled", "marginal",
# "conditional") are left unchanged by recode().
results$explanation_type <- recode(results$explanation_type,
                                   "oracle_mec_sample" = "oracle_sample",
                                   "oracle_mec_iw" = "oracle_iw",
                                   "pc_1000n_alpha0.05_standardized_mec_sample" = "pc_sample",
                                   "pc_1000n_alpha0.05_standardized_mec_iw" = "pc_iw",
                                   "fges_1000_standardized_mec_sample" = "fges_sample",
                                   "fges_1000_standardized_mec_iw" = "fges_iw")

# Make sure the output directory exists so the results are not lost at the
# very end of the run; dir.create() is silent when the directory is already
# there.
output_path <- 'R/experiments/results/non-linear1/results_ncombs.csv'
dir.create(dirname(output_path), recursive = TRUE, showWarnings = FALSE)

write_delim(results, output_path, delim = ',')
print('Done')